首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >非常奇怪的tensorflow行为

非常奇怪的tensorflow行为
EN

Stack Overflow用户
提问于 2018-03-23 08:15:48
回答 1查看 79关注 0票数 2

我有非常简单的行,它们会产生非常奇怪的意外行为:

代码语言:javascript
复制
import tensorflow as tf

y = tf.Variable(2, dtype=tf.int32)

a1 = tf.assign(y, y + 1)
a2 = tf.assign(y, y * 2)

with tf.control_dependencies([a1, a2]):
    t = y+0

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for i in range(4):
        print('t=%d' % sess.run(t))
        print('y=%d' % sess.run(y))

人们期望的是

代码语言:javascript
复制
t=6
y=6
t=14
y=14
t=30
y=30
t=62
y=62

但第一步,我得到了:

代码语言:javascript
复制
t=6
y=6
t=13
y=13
t=26
y=26
t=27
y=27

第二轮,我有:

代码语言:javascript
复制
t=3
y=3
t=6
y=6
t=14
y=14
t=15
y=15

第三轮,我有:

代码语言:javascript
复制
t=6
y=6
t=14
y=14
t=28
y=28
t=56
y=56

非常可笑,多次运行会产生多个不同的输出序列,很奇怪,有人能帮忙吗?

编辑:更改为

代码语言:javascript
复制
import tensorflow as tf
import os
y = tf.Variable(2, dtype=tf.int32)

a1 = tf.assign(y, y + 1)
a2 = tf.assign(y, y * 2)
a3 = tf.group(a1, a2)
with tf.control_dependencies([a3]):
    t = tf.identity(y+0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for i in range(4):

        print('t=%d' % sess.run(t))
        print('y=%d' % sess.run(y))

...still不能正常工作。

奇怪的是,这个代码:

代码语言:javascript
复制
a1 = tf.assign(y, y + 1)
with tf.control_dependencies([a1]):
  a2 = tf.assign(y, y * 2)
  with tf.control_dependencies([a2]):
    t = tf.identity(y)

..。工作正常,但只需将a2移动到前面的

代码语言:javascript
复制
a1 = tf.assign(y, y + 1)
a2 = tf.assign(y, y * 2)
with tf.control_dependencies([a1]):
  with tf.control_dependencies([a2]):
    t = tf.identity(y)

..。不是的。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-03-23 10:45:11

您的方法的问题是,a1a2的顺序也很重要:您希望在a2之前对a1进行评估。tf.control_dependencies([a1, a2])保证在a1a2之后执行t,但它们本身可以按任何顺序进行计算。

我会像这样明确地依赖:

代码语言:javascript
复制
y = tf.Variable(2, dtype=tf.int32)
a1 = tf.assign(y, y + 1)
with tf.control_dependencies([a1]):
  a2 = tf.assign(y, y * 2)
  with tf.control_dependencies([a2]):
    t = tf.identity(y)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for i in range(4):
    print('t=%d' % sess.run(t))
    print('y=%d' % sess.run(y))

输出:

代码语言:javascript
复制
t=6
y=6
t=14
y=14
t=30
y=30
t=62
y=62
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49445111

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档