这方面的文档并不是100%清晰的:
注意:培训时,需要更新moving_mean和moving_variance。默认情况下,更新操作放在tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为依赖项添加到train_op中。例如:
这是否意味着保存moving_mean和moving_variance所需的全部内容如下?
def model_fn(features, labels, mode, params):
training = mode == tf.estimator.ModeKeys.TRAIN
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
x = tf.reshape(features, [-1, 64, 64, 3])
x = tf.layers.batch_normalization(x, training=training)
# ...
with tf.control_dependencies(extra_update_ops):
train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())换句话说,简单地使用
with tf.control_dependencies(extra_update_ops):注意保存moving_mean和moving_variance
发布于 2018-03-16 23:02:28
事实证明,这些值可以自动保存。边缘情况是,如果在将批处理规范化op添加到图形之前获得update ops集合,则update集合将为空。这以前没有被记录过,但现在是。
使用batch_norm时要注意的是在调用tf.layers.batch_normalization之后调用tf.layers.batch_normalization。
发布于 2018-03-12 16:16:27
是的,添加这些控件依赖将节省平均值和方差。
https://stackoverflow.com/questions/49204810
复制相似问题