首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用不同的可训练变量从Inception-3检查点恢复训练

如何使用不同的可训练变量从Inception-3检查点恢复训练
EN

Stack Overflow用户
提问于 2017-05-06 11:23:45
回答 1查看 573关注 0票数 1

我有一个非常常见的用例,冻结了Inception的底层,只训练了前两层,之后我降低了学习率,并对整个Inception模型进行了微调。

下面是我运行第一部分的代码

代码语言:javascript
复制
train_dir='/home/ubuntu/pynb/TF play/log-inceptionv3flowers'
with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    dataset = get_dataset()
    images, _, labels = load_batch(dataset, batch_size=32)

    # Create the model, use the default arg scope to configure the batch norm parameters.
    with slim.arg_scope(inception.inception_v3_arg_scope()):
        logits, _ = inception.inception_v3(images, num_classes=5, is_training=True)

    # Specify the loss function:
    one_hot_labels = slim.one_hot_encoding(labels, 5)
    tf.losses.softmax_cross_entropy(one_hot_labels, logits)
    total_loss = tf.losses.get_total_loss()

    # Create some summaries to visualize the training process:
    tf.summary.scalar('losses/Total Loss', total_loss)

    # Specify the optimizer and create the train op:
    optimizer = tf.train.RMSPropOptimizer(0.001, 0.9,
                                    momentum=0.9, epsilon=1.0)
    train_op = slim.learning.create_train_op(total_loss, optimizer, variables_to_train=get_variables_to_train())

    # Run the training:
    final_loss = slim.learning.train(
        train_op,
        logdir=train_dir,
        init_fn=get_init_fn(),
        number_of_steps=4500,
        save_summaries_secs=30,
        save_interval_secs=30,
        session_config=tf.ConfigProto(gpu_options=gpu_options))

print('Finished training. Last batch loss %f' % final_loss)

,然后是运行第二部分的代码。

代码语言:javascript
复制
train_dir='/home/ubuntu/pynb/TF play/log-inceptionv3flowers'
with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    dataset = get_dataset()
    images, _, labels = load_batch(dataset, batch_size=32)

    # Create the model, use the default arg scope to configure the batch norm parameters.
    with slim.arg_scope(inception.inception_v3_arg_scope()):
        logits, _ = inception.inception_v3(images, num_classes=5, is_training=True)

    # Specify the loss function:
    one_hot_labels = slim.one_hot_encoding(labels, 5)
    tf.losses.softmax_cross_entropy(one_hot_labels, logits)
    total_loss = tf.losses.get_total_loss()
    # Create some summaries to visualize the training process:
    tf.summary.scalar('losses/Total Loss', total_loss)

    # Specify the optimizer and create the train op:
    optimizer = tf.train.RMSPropOptimizer(0.0001, 0.9,
                                    momentum=0.9, epsilon=1.0)
    train_op = slim.learning.create_train_op(total_loss, optimizer)

    # Run the training:
    final_loss = slim.learning.train(
        train_op,
        logdir=train_dir,
        init_fn=get_init_fn(),
        number_of_steps=10000,
        save_summaries_secs=30,
        save_interval_secs=30,
        session_config=tf.ConfigProto(gpu_options=gpu_options))

print('Finished training. Last batch loss %f' % final_loss)

请注意,在第二部分中,我没有向create_train_opvariables_to_train参数传递任何内容。

代码语言:javascript
复制
NotFoundError (see above for traceback): Key InceptionV3/Conv2d_4a_3x3/BatchNorm/beta/RMSProp not found in checkpoint
     [[Node: save_1/RestoreV2_49 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_1/Const_0, save_1/RestoreV2_49/tensor_names, save_1/RestoreV2_49/shape_and_slices)]]
     [[Node: save_1/Assign_774/_1550 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_2911_save_1/Assign_774", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]

我怀疑它正在寻找InceptionV3/Conv2d_4a_3x3层的RMSProp变量,这是不存在的,因为我没有在前面的检查点中训练该层。我不确定如何实现我想要的,因为我在文档中看不到如何做到这一点的示例。

EN

回答 1

Stack Overflow用户

发布于 2017-05-09 00:48:41

TF Slim支持从变量名不匹配的检查点读取数据,如下所述:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/learning.py#L146

您可以指定检查点中的变量名称如何映射到模型中的变量。

我希望这对你有帮助!

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/43816394

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档