我在使用tensorflow输入管道,如cifar10模型在tensorflow中,并尝试使用tf.cond进行验证,并编写了如下所示
train_data = model.input(istrain=True)
val_data = model.input(istrain=False)
# This selects which stream to use.
select_val = tf.placeholder(dtype=bool,shape=[],name='select_test')
data = tf.cond(
select_val,
lambda:val_data,
lambda:train_data
)
# Here is the model.
loss = ...
train_op = ...
...
with tf.Session():
...如果删除cond,只使用训练数据,速度为4000个样本/s,如果使用上面的代码,速度会下降到2300个样本/s,验证流水线容量设置得很小,因此在GPU中不会占用太多内存。执行验证的频率也很低。我不知道出了什么问题,请帮帮我。
发布于 2017-03-30 03:51:40
tf.cond并不完全懒惰。即使需要的分支不是要执行的分支,cond的任一分支所需的任何操作都将被运行。因此,在您的示例中,每次调用model.input(istrain=True) op时都会执行data和model.input(istrain=False)。其中之一的结果被忽略了。
cond给出了一个最小的代码示例:
注意,条件执行只适用于在fn1和fn2中定义的操作。考虑以下简单的程序: Z= tf.multiply(a,b)结果= tf.cond(x < y,lambda: tf.add(x,z),lambda: tf.square(y)) 如果x< y,则将执行tf.add操作,而不执行tf.square操作。由于cond的至少一个分支需要z,所以始终无条件地执行tf.mul操作。虽然这种行为与TensorFlow的数据流模型是一致的,但它偶尔也会让一些用户感到惊讶,他们期望使用更懒惰的语义。
还请注意,这意味着如果您的model.input从更大的池中提取了一些数据集(例如,从整个数据集中提取一批数据),每次运行cond时,都会从验证和培训中提取数据,并且只会丢弃一组数据。在某些情况下,这会导致比低效率更严重的问题。例如,如果您只处理一定数量的历元,那么使用此代码,您实际上不会处理该数量的历元,因为数据是被提取的,而不是使用的。
https://stackoverflow.com/questions/43107678
复制相似问题