首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tf.cond降低了训练速度

tf.cond降低了训练速度
EN

Stack Overflow用户
提问于 2017-03-30 02:53:54
回答 1查看 873关注 0票数 5

我在使用tensorflow输入管道,如cifar10模型在tensorflow中,并尝试使用tf.cond进行验证,并编写了如下所示

代码语言:javascript
复制
train_data = model.input(istrain=True)
val_data = model.input(istrain=False)

# This selects which stream to use.
select_val = tf.placeholder(dtype=bool,shape=[],name='select_test')
data = tf.cond(
    select_val,
    lambda:val_data,
    lambda:train_data
)

# Here is the model.
loss = ...
train_op = ...
...

with tf.Session():
    ...

如果删除cond,只使用训练数据,速度为4000个样本/s,如果使用上面的代码,速度会下降到2300个样本/s,验证流水线容量设置得很小,因此在GPU中不会占用太多内存。执行验证的频率也很低。我不知道出了什么问题,请帮帮我。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-03-30 03:51:40

tf.cond并不完全懒惰。即使需要的分支不是要执行的分支,cond的任一分支所需的任何操作都将被运行。因此,在您的示例中,每次调用model.input(istrain=True) op时都会执行datamodel.input(istrain=False)。其中之一的结果被忽略了。

cond给出了一个最小的代码示例:

注意,条件执行只适用于在fn1和fn2中定义的操作。考虑以下简单的程序: Z= tf.multiply(a,b)结果= tf.cond(x < y,lambda: tf.add(x,z),lambda: tf.square(y)) 如果x< y,则将执行tf.add操作,而不执行tf.square操作。由于cond的至少一个分支需要z,所以始终无条件地执行tf.mul操作。虽然这种行为与TensorFlow的数据流模型是一致的,但它偶尔也会让一些用户感到惊讶,他们期望使用更懒惰的语义。

还请注意,这意味着如果您的model.input从更大的池中提取了一些数据集(例如,从整个数据集中提取一批数据),每次运行cond时,都会从验证和培训中提取数据,并且只会丢弃一组数据。在某些情况下,这会导致比低效率更严重的问题。例如,如果您只处理一定数量的历元,那么使用此代码,您实际上不会处理该数量的历元,因为数据是被提取的,而不是使用的。

票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/43107678

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档