首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用估计API更新batch_normalization均值和方差

用估计API更新batch_normalization均值和方差
EN

Stack Overflow用户
提问于 2018-03-10 02:41:46
回答 2查看 826关注 0票数 2

这方面的文档并不是100%清晰的:

注意:培训时,需要更新moving_mean和moving_variance。默认情况下,更新操作放在tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为依赖项添加到train_op中。例如:

(见normalization)

这是否意味着保存moving_meanmoving_variance所需的全部内容如下?

代码语言:javascript
复制
def model_fn(features, labels, mode, params):
   training = mode == tf.estimator.ModeKeys.TRAIN
   extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

   x = tf.reshape(features, [-1, 64, 64, 3])
   x = tf.layers.batch_normalization(x, training=training)

   # ...

  with tf.control_dependencies(extra_update_ops):
     train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())

换句话说,简单地使用

代码语言:javascript
复制
with tf.control_dependencies(extra_update_ops):

注意保存moving_meanmoving_variance

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-03-16 23:02:28

事实证明,这些值可以自动保存。边缘情况是,如果在将批处理规范化op添加到图形之前获得update ops集合,则update集合将为空。这以前没有被记录过,但现在是。

使用batch_norm时要注意的是在调用tf.layers.batch_normalization之后调用tf.layers.batch_normalization

票数 1
EN

Stack Overflow用户

发布于 2018-03-12 16:16:27

是的,添加这些控件依赖将节省平均值和方差。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49204810

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档