首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >无法使用香草Tensorflow代码实现与使用TF-slim的相同性能训练

无法使用香草Tensorflow代码实现与使用TF-slim的相同性能训练
EN

Stack Overflow用户
提问于 2018-06-17 23:07:58
回答 1查看 234关注 0票数 0

下面的代码使用TF库加载模型并完成它在分类任务中的性能达到90% (我省略了加载数据和预处理):

代码语言:javascript
复制
with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=0.0001)):
    logits, _ = resnet_v1.resnet_v1_50(images, num_classes=dataset.num_classes, is_training=True)

one_hot_labels = slim.one_hot_encoding(labels, NUM_CLASSES)
tf.losses.softmax_cross_entropy(one_hot_labels, logits)
total_loss = tf.losses.get_total_loss()
global_step = variables.get_or_create_global_step()
lr = tf.train.exponential_decay(LEARNING_RATE, global_step, DECAY_STEPS, GAMMA)
optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=MOMENTUM)
train_op = slim.learning.create_train_op(total_loss, optimizer, global_step=global_step)
init_fn = slim.assign_from_checkpoint_fn("resnet_v1_50.ckpt", VARIABLES_TO_RESTORE)

final_loss = slim.learning.train( train_op, logdir=train_dir, log_every_n_steps=500, save_summaries_secs=25,  init_fn=init_fn, number_of_steps = NUM_STEPS)

我尝试使用vanilla重写相同的代码,以便更好地控制培训过程,而且由于某些原因,当使用所有相同的超参数(大写)和相同的预处理时,我无法达到相同的性能(性能下降10%)。不同之处在于图的定义:

代码语言:javascript
复制
        lr = tf.train.exponential_decay(LEARNING_RATE,  global_step, DECAY_STEPS, GAMMA)
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=MOMENTUM)
        full_train_op = optimizer.minimize(total_loss, global_step=global_step)

和培训:

代码语言:javascript
复制
for s in range(NUM_STEPS):
    sess.run(train_init_op) #Initializes dataset iterator
    while True:
        try:
            sess.run([full_train_op], feed_dict={is_training: True})                    
        except tf.errors.OutOfRangeError:
            break

超薄的火车功能还在做其他的操作吗?我认为它可能是使用批处理规范化,或者是在代码的版本上没有实现的其他东西。

是否可以在tensorflow中加载瘦resnet模型,并在没有超薄火车功能的情况下对其进行训练?我对重写train_step_fn不感兴趣。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-06-18 03:56:03

这可能是由于没有运行与resnet的批处理规范相关联的update_ops

代码语言:javascript
复制
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=MOMENTUM)
with tf.control_dependencies(update_ops):
    full_train_op = optimizer.minimize(total_loss, global_step)
# same training loop
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50901159

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档