首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >AdamOptimizer挑战tf.control_dependencies

AdamOptimizer挑战tf.control_dependencies
EN

Stack Overflow用户
提问于 2019-08-07 00:28:49
回答 2查看 112关注 0票数 2

不知何故,AdamOptimizer挑战了tf.control_dependencies

这是一个测试。我要求TensorFlow做以下工作:

  1. 计算损失
  2. 计算损失
  3. 跑一步亚当

我使用依赖关系来“确保”TF在运行步骤2之后运行步骤3。

如果TensorFlow按照正确的顺序执行这三个步骤,那么步骤1和步骤2的结果应该是相同的。

但事实并非如此。怎么啦?

测试:

代码语言:javascript
复制
import numpy as np
import tensorflow as tf

x = tf.get_variable('x', initializer=np.array([1], dtype=np.float64))
loss = x * x

optim = tf.train.AdamOptimizer(1)

## Control Dependencies ##
with tf.control_dependencies([loss]):
    train_op = optim.minimize(loss)

## Run ##
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    for i in range(1000):
        a = sess.run([loss])
        b = sess.run([loss, train_op])[0]
        print(a, b)
        assert np.allclose(a, b)

结果:

代码语言:javascript
复制
[array([1.])] [2.50003137e-14]
AssertionError

步骤1和步骤2的结果并不相同。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-08-07 01:18:42

根据tf.identity(loss)进行第3步会神奇地解决问题。

怎么一回事??

魔法修复:

代码语言:javascript
复制
import numpy as np
import tensorflow as tf

x = tf.get_variable('x', initializer=np.array([1], dtype=np.float64))
loss = x * x

optim = tf.train.AdamOptimizer(1)

## Control Dependencies ##
loss2 = tf.identity(loss)  # <--- this #
with tf.control_dependencies([loss2]):
    train_op = optim.minimize(loss)

## Run ##
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    for i in range(1000):
        a = sess.run([loss])
        b = sess.run([loss2, train_op])[0]  # <--- loss2
        print(a, b)
        assert np.allclose(a, b)

结果:

代码语言:javascript
复制
[array([1.])] [1.]
[array([2.50003137e-14])] [2.50003137e-14]
[array([0.4489748])] [0.4489748]
...
[array([1.151504e-47])] [1.151504e-47]
[array([4.90468459e-46])] [4.90468459e-46]
票数 0
EN

Stack Overflow用户

发布于 2019-08-07 00:49:04

听起来,您希望sess.run([loss, adam_op])运行loss,然后运行adam_op。唉,sess.run不是那样工作的。以这个简单的例子为例--它打印1.0 1.0,表明set_x op在get_x之前运行。

代码语言:javascript
复制
import tensorflow as tf

var_x = tf.get_variable("x", shape=[], initializer=tf.zeros_initializer())
get_x = var_x.read_value()
set_x = var_x.assign(1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a, b = sess.run([get_x, set_x])
    print(a, b)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57385532

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档