首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在Tensorflow 2中结合使用GradientTape和AutoGraph?

如何在Tensorflow 2中结合使用GradientTape和AutoGraph?
EN

Stack Overflow用户
提问于 2019-09-07 04:49:12
回答 1查看 531关注 0票数 2

我不知道如何在TensorFlow2中的AutoGraph上运行GradientTape代码。

我想在TPU上运行GradientTape代码。我想从在CPU上测试它开始。使用AutoGraph,TPU代码的运行速度会快得多。我尝试观察输入变量,并尝试将参数传递给包含GradientTape的函数,但都失败了。

我在这里做了一个可重现的例子:https://colab.research.google.com/drive/1luCk7t5SOcDHQC6YTzJzpSixIqHy_76b#scrollTo=9OQUYQtTTYIt

代码和相应的输出如下所示:它们都以import tensorflow as tf开头

代码语言:javascript
复制
x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = x * x
dy_dx = g.gradient(y, x)
print(dy_dx)

Output:tf.Tensor(6.0, shape=(), dtype=float32)解释:使用紧急执行,GradientTape生成渐变。

代码语言:javascript
复制
@tf.function
def compute_me():
    x = tf.constant(3.0)
    with tf.GradientTape() as g:
      g.watch(x)
      y = x * x
    dy_dx = g.gradient(y, x) # Will compute to 6.0
    print(dy_dx)
compute_me()

输出:Tensor("AddN:0", shape=(), dtype=float32)解释:在TF2中对GradientTape使用AutoGraph会导致空渐变

代码语言:javascript
复制
@tf.function
def compute_me_args(x):
    with tf.GradientTape() as g:
      g.watch(x)
      y = x * x
    dy_dx = g.gradient(y, x) # Will compute to 6.0
    print(dy_dx)    
x = tf.constant(3.0)
compute_me_args(x)

输出:Tensor("AddN:0", shape=(), dtype=float32)解释:传入参数也会失败

我原以为所有的单元格都会输出tf.Tensor(6.0, shape=(), dtype=float32),但实际上这些单元格使用的是AutoGraph output Tensor("AddN:0", shape=(), dtype=float32)

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-09-07 05:36:23

它不是“失败”,只是如果在tf.function的上下文中使用(即在图形模式中),print将打印符号张量,而这些张量没有值。试着这样做:

代码语言:javascript
复制
@tf.function
def compute_me():
    x = tf.constant(3.0)
    with tf.GradientTape() as g:
        g.watch(x)
        y = x * x
    dy_dx = g.gradient(y, x) # Will compute to 6.0
    tf.print(dy_dx)
compute_me()

这应该会打印6。您所需要做的就是使用tf.print,它足够“智能”来打印实际值(如果可用)。或者,使用返回值:

代码语言:javascript
复制
@tf.function
def compute_me():
    x = tf.constant(3.0)
    with tf.GradientTape() as g:
        g.watch(x)
        y = x * x
    dy_dx = g.gradient(y, x) # Will compute to 6.0
    return dy_dx
result = compute_me()
print(result)

输出类似于<tf.Tensor: id=43, shape=(), dtype=float32, numpy=6.0>的内容。您可以看到值(6.0)在这里也是可见的。使用result.numpy()只需获取6.0

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57828416

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档