我正在使用TF-Slim从预先训练的模型微调我的模型。当我使用create_train_op时,我发现它有一个参数是variables_to_train。在一些教程中,它使用了如下标志:
all_trainable = [v for v in tf.trainable_variables()]
trainable = [v for v in all_trainable]
train_op = slim.learning.create_train_op(
opt,
global_step=global_step,
variables_to_train=trainable,
summarize_gradients=True)但在官方的TF-Slim中,它没有使用
all_trainable = [v for v in tf.trainable_variables()]
trainable = [v for v in all_trainable]
train_op = slim.learning.create_train_op(
opt,
global_step=global_step,
summarize_gradients=True)那么,使用和不使用variables_to_train有什么不同呢?
发布于 2018-07-21 17:30:39
你的两个例子都做了同样的事情。您可以训练出现在图中的所有可训练变量。使用参数variables_to_train,您可以定义哪些变量应该在训练期间更新。
这种情况的一个用例是,当你有预先训练好的东西,比如单词嵌入,而你不想在你的模型中训练。使用
train_vars = [v for v in tf.trainable_variables() if "embeddings" not in v.name]
train_op = slim.learning.create_train_op(
opt,
global_step=global_step,
variables_to_train=train_vars,
summarize_gradients=True)您可以从训练中排除名称中包含"embeddings"的所有变量。如果您只是想训练所有变量,则不必定义train_vars,并且可以在没有参数variables_to_train的情况下创建训练op。
https://stackoverflow.com/questions/48558181
复制相似问题