首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TensorFlow:批处理Norm在is_training = False时中断网络

TensorFlow:批处理Norm在is_training = False时中断网络
EN

Stack Overflow用户
提问于 2017-05-26 23:25:05
回答 1查看 1.4K关注 0票数 6

我试图使用来自TensorFlow-Slim的批处理规范层,如下所示:

代码语言:javascript
复制
net = ...
net = slim.batch_norm(net, scale = True, is_training = self.isTraining,
    updates_collections = None, decay = 0.9)
net = tf.nn.relu(net)
net = ...

我训练的对象是:

代码语言:javascript
复制
self.optimizer = slim.learning.create_train_op(self.model.loss,
    tf.train.MomentumOptimizer(learning_rate = self.learningRate,
    momentum = 0.9, use_nesterov = True)

optimizer = self.sess.run([self.optimizer],
    feed_dict={self.model.isTraining:True})

我用以下内容加载保存的重量:

代码语言:javascript
复制
net = model.Model(sess,width,height,channels,weightDecay)

savedWeightsDir = './savedWeights/'
saver = tf.train.Saver(max_to_keep = 5)
checkpointStr = tf.train.latest_checkpoint(savedWeightsDir)
sess.run(tf.global_variables_initializer())
saver.restore(sess, checkpointStr)
global_step = tf.contrib.framework.get_or_create_global_step()

我推断:

代码语言:javascript
复制
inf = self.sess.run([self.softmax],
    feed_dict = {self.imageBatch:imageBatch,self.isTraining:False})

当然,我遗漏了很多代码,并解释了一些代码,但我认为这就是批处理规范所涉及的全部内容。奇怪的是,如果我设定的是训练:没错,我会得到更好的结果。它是否可能是在权重中加载的东西--也许批处理的范数没有保存?代码中有什么明显的错误吗?谢谢。

EN

回答 1

Stack Overflow用户

发布于 2018-09-08 17:05:03

我刚刚遇到了同样的问题,并找到了解在这里。该问题源于需要更新moving_variance.和moving_meantf.layers.batch_normalization层。

为了正确地做到这一点,在您的情况下,您需要修改您的培训过程如下:

代码语言:javascript
复制
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    self.optimizer = slim.learning.create_train_op(self.model.loss,
      tf.train.MomentumOptimizer(learning_rate = self.learningRate,
      momentum = 0.9, use_nesterov = True)

或者更广泛地说,来自文档

代码语言:javascript
复制
  x_norm = tf.layers.batch_normalization(x, training=training)

  # ...

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/44211371

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档