Tf函数不改变对象的属性
class f:
v = 7
def __call__(self):
self.v = self.v + 1
@tf.function
def call(c):
tf.print(c.v) # always 7
c()
tf.print(c.v) # always 8
c = f()
call(c)
call(c)预期印刷量:7 8 8 9
但相反:7 8 7 8
当我移除@tf.function装饰器时,所有功能都如预期的那样工作。如何使我的函数在@tf.function中按预期工作
发布于 2021-11-29 11:44:28
此行为已记录为这里。
副作用,如打印,附加到列表,和变异的全局,可以在一个函数中意外地行为,有时执行两次或不是全部。它们只在第一次使用一组输入调用函数时发生。之后,将重新执行跟踪的tf.Graph,而不执行Python的一般经验规则是避免在逻辑中依赖于code.The的副作用,而只使用它们来调试跟踪。否则,TensorFlow API(如tf.data、tf.print、tf.summary、tf.Variable.assign和tf.TensorArray )是确保代码在每次调用时由TensorFlow运行时执行的最佳方法。
因此,可以尝试使用tf.Variable查看预期的更改:
import tensorflow as tf
class f:
v = tf.Variable(7)
def __call__(self):
self.v.assign_add(1)
@tf.function
def call(c):
tf.print(c.v) # always 7
c()
tf.print(c.v) # always 8
c = f()
call(c)
call(c)7
8
8
9https://stackoverflow.com/questions/70153423
复制相似问题