Are GraphKeys.TRAINABLE_VARIABLES and tf.trainable_variables() the same thing?
Is GraphKeys.TRAINABLE_VARIABLES actually tf.GraphKeys.TRAINABLE_VARIABLES?
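Conceptually, tf.trainable_variables() is just a thin wrapper that reads the graph collection stored under the key GraphKeys.TRAINABLE_VARIABLES. A simplified plain-Python stand-in for that collection mechanism (no TensorFlow required; the dict and helper names are illustrative, not TF internals):

```python
# Simplified stand-in for TensorFlow's graph-collection mechanism.
# GraphKeys.TRAINABLE_VARIABLES is just a string key under which
# variables are registered at creation time (when trainable=True).
TRAINABLE_VARIABLES = "trainable_variables"  # the key's actual string value in TF

collections = {}  # plays the role of the graph's internal collections map

def add_to_collection(key, value):
    collections.setdefault(key, []).append(value)

def get_collection(key):
    # returns a copy, like tf.get_collection()
    return list(collections.get(key, []))

def trainable_variables():
    # tf.trainable_variables() is equivalent to
    # tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    return get_collection(TRAINABLE_VARIABLES)

# creating a trainable "variable" registers it in the collection
add_to_collection(TRAINABLE_VARIABLES, "kernel:0")
add_to_collection(TRAINABLE_VARIABLES, "bias:0")

assert trainable_variables() == get_collection(TRAINABLE_VARIABLES)
```

So the two are the same list: the function is shorthand for a get_collection() call with that key.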
The network seems to train successfully with:
```python
optimizer = tf.train.AdamOptimizer(config.LEARNING_RATE)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    self.train_op = optimizer.minimize(self.loss, var_list=tf.trainable_variables())
```

but not with:
```python
optimizer = tf.train.AdamOptimizer(config.LEARNING_RATE)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    self.train_op = optimizer.minimize(self.loss)
```

According to the documentation:
> var_list: Optional list or tuple of Variable objects to update to minimize loss. Defaults to the list of variables collected in the graph under the key GraphKeys.TRAINABLE_VARIABLES.

Also, as I've seen in the batch normalization example, var_list is omitted:
```python
x_norm = tf.layers.batch_normalization(x, training=training)
# ...
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
```

Answered on 2019-04-10 19:02:56
If you don't pass var_list to the minimize() function, the variables are retrieved as follows (taken from the compute_gradients() source code):
```python
if var_list is None:
    var_list = (
        variables.trainable_variables() +
        ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
```

If you haven't defined any ResourceVariable instances that are not in tf.trainable_variables(), the result should be identical. My guess is that the problem lies elsewhere.
You can run a check like the following before calling minimize(), to make sure there is no ResourceVariable missing from tf.trainable_variables():
```python
import tensorflow as tf

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2])
    with tf.name_scope('network'):
        logits = tf.layers.dense(x, units=2)
    var_list = (tf.trainable_variables()
                + tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    assert set(var_list) == set(tf.trainable_variables())
```

https://stackoverflow.com/questions/55619070