首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >GraphKeys.TRAINABLE_VARIABLES诉tf.trainable_variables()

GraphKeys.TRAINABLE_VARIABLES诉tf.trainable_variables()
EN

Stack Overflow用户
提问于 2019-04-10 18:24:25
回答 1查看 1K关注 0票数 1

GraphKeys.TRAINABLE_VARIABLEStf.trainable_variables()是一样的吗?

GraphKeys.TRAINABLE_VARIABLES实际上是tf.GraphKeys.TRAINABLE_VARIABLES吗?

看来网络成功地培训了:

代码语言:javascript
复制
optimizer = tf.train.AdamOptimizer(config.LEARNING_RATE)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    self.train_op = optimizer.minimize(self.loss, var_list=tf.trainable_variables())

但不是和

代码语言:javascript
复制
optimizer = tf.train.AdamOptimizer(config.LEARNING_RATE)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    self.train_op = optimizer.minimize(self.loss)

根据文档

代码语言:javascript
复制
var_list: Optional list or tuple of Variable objects to update to minimize loss. Defaults to the list of variables collected in the graph under the key GraphKeys.TRAINABLE_VARIABLES.

此外,正如我在批处理规范化示例中所看到的,省略了var_list

代码语言:javascript
复制
  x_norm = tf.layers.batch_normalization(x, training=training)

  # ...

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-04-10 19:02:56

如果不将var_list传递给minimize()函数,则将以以下方式检索变量(取自compute_gradients() 源代码):

代码语言:javascript
复制
if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))

如果您没有在ResourceVariable中定义任何tf.trainable_variables()实例,那么结果应该是相同的。我猜问题出在别的地方。

您可以尝试在调用minimize()之前执行一些测试,以确保没有不在tf.trainable_variables()中的ResourceVariable

代码语言:javascript
复制
import tensorflow as tf

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2])
    with tf.name_scope('network'):
        logits = tf.layers.dense(x, units=2)

    var_list = (tf.trainable_variables()
                + tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    assert set(var_list) == set(tf.trainable_variables())
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55619070

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档