首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >无法使用tf.stop_gradient

无法使用tf.stop_gradient
EN

Stack Overflow用户
提问于 2019-02-28 15:53:11
回答 1查看 149关注 0票数 0

目前,我正试图了解tf.stop_gradient是如何工作的,为此,我使用了下面这个小代码片段

代码语言:javascript
复制
tf.reset_default_graph()
w1 = tf.get_variable(name = 'w1',initializer=tf.constant(10, dtype=tf.float32))
w2 = tf.get_variable(name = 'w2',initializer=tf.constant(3,dtype=tf.float32), trainable=True)
inter = w1*w2
inter=tf.stop_gradient(inter)
loss = w1*w1 - inter  - 10
opt = tf.train.GradientDescentOptimizer(learning_rate = 0.0001)


gradients = opt.compute_gradients(loss)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(gradients))

错误: TypeError: Fetch参数没有无效类型

如果我使用tf.stop_gradient注释掉这一行代码,代码就会运行良好,并且与预期的一样。请指导我如何使用tf.stop_gradient

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-03-01 04:14:21

你正确地使用了tf.stop_gradient。但是,TensorFlow通过删除指向loss的所有图连接来停止inter的梯度。因此,如果使用Noneopt.compute_gradients计算dLoss/dw2,则会返回[1]

返回None使之明确表示两者之间没有图形连接。

TypeError就是这样出现的(dLoss/dw1没有这个问题)。许多用户(包括我自己)认为这种梯度应该是0而不是None,但是TensorFlow工程师坚持认为这是有意的行为。

幸运的是,有一些解决办法,尝试下面的代码:

代码语言:javascript
复制
import tensorflow as tf

w1 = tf.get_variable(name='w1', initializer=tf.constant(10, dtype=tf.float32))
w2 = tf.get_variable(name='w2', initializer=tf.constant(3, dtype=tf.float32))
inter = w1 * w2
inter = tf.stop_gradient(inter)
loss = w1*w1 - inter - 10
dL_dW = tf.gradients(loss, [w1, w2])
# Replace None gradient with 0 manully
dL_dW = [tf.constant(0) if grad is None else grad for grad in dL_dW]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dL_dW))
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54929570

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档