首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如果将MobileNet设置为is_training为false,则is_training不可用

如果将MobileNet设置为is_training为false,则is_training不可用
EN

Stack Overflow用户
提问于 2017-08-24 07:28:13
回答 2查看 1.8K关注 0票数 6

对此问题的更准确的描述是,当MobileNet没有显式地设置为true时,is_training的行为会很糟糕。我指的是TensorFlow在其模型存储库v1.py中提供的v1.py

我就是这样创建网络(phase_train=True)的:

代码语言:javascript
复制
with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=phase_train)):
        features, endpoints = mobilenet_v1.mobilenet_v1(
            inputs=images_placeholder, features_layer_size=features_layer_size, dropout_keep_prob=dropout_keep_prob,
            is_training=phase_train)

我正在训练一个识别网络,在训练的同时,我在LFW上测试。我在训练中得到的成绩随着时间的推移而提高,并且取得了很好的准确性。

在部署之前,我冻结了图表。如果我用is_training=True冻结图形,我在LFW上得到的结果和训练中的结果是一样的。但是如果我设置is_training=False,我就会得到结果,就像网络根本没有训练过一样.

这种行为实际上发生在其他网络中,比如“盗梦空间”。

我倾向于认为我错过了一些非常基本的东西,而且这不是TensorFlow中的一个bug .

任何帮助都将不胜感激。

添加更多代码.

这就是我准备培训的方式:

代码语言:javascript
复制
images_placeholder = tf.placeholder(tf.float32, shape=(None, image_size, image_size, 1), name='input')
labels_placeholder = tf.placeholder(tf.int32, shape=(None))
dropout_placeholder = tf.placeholder_with_default(1.0, shape=(), name='dropout_keep_prob')
phase_train_placeholder = tf.Variable(True, name='phase_train')

global_step = tf.Variable(0, name='global_step', trainable=False)

# build graph

with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=phase_train_placeholder)):
    features, endpoints = mobilenet_v1.mobilenet_v1(
        inputs=images_placeholder, features_layer_size=512, dropout_keep_prob=1.0,
        is_training=phase_train_placeholder)

# loss

logits = slim.fully_connected(inputs=features, num_outputs=train_data.get_class_count(), activation_fn=None,
                              weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
                              weights_regularizer=slim.l2_regularizer(scale=0.00005),
                              scope='Logits', reuse=False)

tf.losses.sparse_softmax_cross_entropy(labels=labels_placeholder, logits=logits,
                                       reduction=tf.losses.Reduction.MEAN)

loss = tf.losses.get_total_loss()

# normalize output for inference

embeddings = tf.nn.l2_normalize(features, 1, 1e-10, name='embeddings')

# optimizer

optimizer = tf.train.AdamOptimizer()
train_op = optimizer.minimize(loss, global_step=global_step)

这是我的火车步骤:

代码语言:javascript
复制
batch_data, batch_labels = train_data.next_batch()
feed_dict = {
    images_placeholder: batch_data,
    labels_placeholder: batch_labels,
    dropout_placeholder: dropout_keep_prob
}
_, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

我可以为如何冻结图形添加代码,但这并不是必要的。用is_train=false构建图形,加载最新的检查点,并在LWF上运行评估来重现问题就足够了。

更新.

我发现问题是在批处理规范化层。只需将此层设置为is_training=false就可以重现问题。

我在发现以下内容后发现的参考资料:

http://ruishu.io/2016/12/27/batchnorm/

https://github.com/tensorflow/tensorflow/issues/10118

Batch Normalization - Tensorflow

将更新的解决方案,一旦我有一个测试的一个。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2017-09-11 14:42:07

所以我找到了解决办法。主要使用此参考资料:http://ruishu.io/2016/12/27/batchnorm/

从链接中:

注意:当is_training为True时,需要更新moving_mean和moving_variance,默认情况下,update_ops放在tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为依赖项添加到train_op中,例如: update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)如果update_ops: update_ops= tf.group(*update_ops) total_loss =control_flow_ops.with_dependencies(更新,total_loss)

并且直截了当地说,而不是像这样创建优化器:

代码语言:javascript
复制
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(total_loss, global_step=global_step)

这样做吧:

代码语言:javascript
复制
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train_op = optimizer.minimize(total_loss, global_step=global_step)

这将解决这个问题。

票数 5
EN

Stack Overflow用户

发布于 2017-08-28 16:02:01

is_training不应该产生这种效果。我需要看到更多的代码来理解正在发生的事情,但是当您将is_training设置为false时,变量名可能不匹配,这可能是因为变量范围重用问题。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45855443

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档