首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >故障恢复校验点TensorFlow网

故障恢复校验点TensorFlow网
EN

Stack Overflow用户
提问于 2016-08-19 09:01:48
回答 1查看 304关注 0票数 0

我已经建立了一个自动编码器,以“转换”从VGG19.relu4_1激活为像素。我在tensorflow.contrib.layers中使用了新的方便函数(如TF 0.10rc0)。代码具有与TensorFlow的CIFAR10教程类似的布局,它有一个train.py,用于训练和检查模型到磁盘,还有一个eval.py,用于轮询新的检查点文件并在它们上运行推断。

我的问题是,评估从来没有培训那么好,无论是在损失函数的值方面,还是在我查看输出图像时(即使在与培训相同的图像上运行)。这让我觉得这与恢复过程有关。

当我查看TensorBoard培训的输出时,它看起来很好(最终),所以我不认为我的网络本身有什么问题。

我的网子是这样的:

代码语言:javascript
复制
import tensorflow.contrib.layers as contrib
bn_params = {                                                                             
    "is_training": is_training,
    "center": True,
    "scale": True
}                                                                                                                                                       

tensor = contrib.convolution2d_transpose(vgg_output, 64*4, 4,                               
    stride=2,
    normalizer_fn=contrib.batch_norm,
    normalizer_params=bn_params,
    scope="deconv1")                                                        
tensor = contrib.convolution2d_transpose(tensor, 64*2, 4,                               
    stride=2,
    normalizer_fn=contrib.batch_norm,
    normalizer_params=bn_params,
    scope="deconv2")
.
.
.

train.py中,我这样做是为了保存检查点:

代码语言:javascript
复制
variable_averages = tf.train.ExponentialMovingAverage(mynet.MOVING_AVERAGE_DECAY)
variables_averages_op = variable_averages.apply(tf.trainable_variables())

with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
    train_op = tf.no_op(name='train')

while training:
    # train (with batch normalization's is_training = True)
    if time_to_checkpoint:
        saver.save(sess, checkpoint_path, global_step=step)

eval.py中,我这样做:

代码语言:javascript
复制
# run code that creates the net

variable_averages = tf.train.ExponentialMovingAverage(
                  mynet.MOVING_AVERAGE_DECAY)
saver = tf.train.Saver(variable_averages.variables_to_restore())

while polling:
    # sleep and check for new checkpoint files
    with tf.Session() as sess:
        init = tf.initialize_all_variables()
        init_local = tf.initialize_local_variables()
        sess.run([init, init_local])
        saver.restore(sess, checkpoint_path)

        # run inference (with batch normalization's is_training = False)

蓝色是训练的损失,橙色是早期的损失。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2016-08-22 07:29:37

问题是我直接使用了tf.train.AdamOptimizer()。在优化过程中,它没有调用contrib.batch_norm中定义的操作来计算输入的运行均值/方差,所以均值/方差总是0.0/1.0。

解决方案是向GraphKeys.UPDATE_OPS集合添加依赖项。已经在contrib模块中定义了一个函数来执行此操作(optimize_loss())

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/39035041

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档