首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tf-slim中的variables_to_train标志

Tf-slim中的variables_to_train标志
EN

Stack Overflow用户
提问于 2018-02-01 16:20:37
回答 1查看 330关注 0票数 1

我正在使用TF-Slim从预先训练的模型微调我的模型。当我使用create_train_op时,我发现它有一个参数是variables_to_train。在一些教程中,它使用了如下标志:

代码语言:javascript
复制
   all_trainable = [v for v in tf.trainable_variables()]
   trainable     = [v for v in all_trainable]
   train_op      = slim.learning.create_train_op(
        opt,
        global_step=global_step,
        variables_to_train=trainable,
        summarize_gradients=True)

但在官方的TF-Slim中,它没有使用

代码语言:javascript
复制
   all_trainable = [v for v in tf.trainable_variables()]
   trainable     = [v for v in all_trainable]
   train_op      = slim.learning.create_train_op(
        opt,
        global_step=global_step,            
        summarize_gradients=True)

那么,使用和不使用variables_to_train有什么不同呢?

EN

回答 1

Stack Overflow用户

发布于 2018-07-21 17:30:39

你的两个例子都做了同样的事情。您可以训练出现在图中的所有可训练变量。使用参数variables_to_train,您可以定义哪些变量应该在训练期间更新。

这种情况的一个用例是,当你有预先训练好的东西,比如单词嵌入,而你不想在你的模型中训练。使用

代码语言:javascript
复制
train_vars = [v for v in tf.trainable_variables() if "embeddings" not in v.name]
train_op      = slim.learning.create_train_op(
    opt,
    global_step=global_step,
    variables_to_train=train_vars,
    summarize_gradients=True)

您可以从训练中排除名称中包含"embeddings"的所有变量。如果您只是想训练所有变量,则不必定义train_vars,并且可以在没有参数variables_to_train的情况下创建训练op。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48558181

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档