首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用@tffunction发出Tensorflow2警告

使用@tffunction发出Tensorflow2警告
EN

Stack Overflow用户
提问于 2019-11-21 10:05:47
回答 2查看 26.1K关注 0票数 22

此示例代码来自Tensorflow 2

代码语言:javascript
复制
writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")

@tf.function
def my_func(step):
  with writer.as_default():
    # other model code would go here
    tf.summary.scalar("my_metric", 0.5, step=step)

for step in range(100):
  my_func(step)
  writer.flush()

但这是在发出警告。

警告:tensorflow:过去5次调用中有5次调用触发了tf.function回溯。跟踪非常昂贵,过多的跟踪可能是由于传递python对象而不是张量造成的。此外,tf.function还有experimental_relax_shapes=True选项,它可以放松参数形状,避免不必要的回溯。有关更多细节,请参阅https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_argshttps://www.tensorflow.org/api_docs/python/tf/function

有更好的方法吗?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-11-21 11:24:57

tf.function有一些“特点”。我强烈建议阅读本文:https://www.tensorflow.org/tutorials/customization/performance

在这种情况下,问题是每次使用不同的输入签名调用时,函数都会被“回溯”(即生成一个新的图)。对于张量,输入签名指的是形状和dtype,但是对于Python数字,每个新值都被解释为“不同的”。在这种情况下,因为使用每次更改的step变量调用函数,所以函数每次都会被回溯。对于“真实”代码(例如,在函数中调用模型),这将是非常慢的。

您可以通过简单地将step转换为张量来修复它,在这种情况下,不同的值将不被算作新的输入签名:

代码语言:javascript
复制
for step in range(100):
    step = tf.convert_to_tensor(step, dtype=tf.int64)
    my_func(step)
    writer.flush()

或者使用tf.range直接获得张量:

代码语言:javascript
复制
for step in tf.range(100):
    step = tf.cast(step, tf.int64)
    my_func(step)
    writer.flush()

这不应该产生警告(而且速度更快)。

票数 37
EN

Stack Overflow用户

发布于 2022-11-30 00:27:03

尝尝这个

我用的是model(x)而不是model.predict(x),它对我有用。

如果您在自定义函数中获得此错误,则为您的函数添加一个固定的shapedtype签名。

代码语言:javascript
复制
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
...
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58972225

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档