此示例代码来自Tensorflow 2
writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")
@tf.function
def my_func(step):
with writer.as_default():
# other model code would go here
tf.summary.scalar("my_metric", 0.5, step=step)
for step in range(100):
my_func(step)
writer.flush()但这是在发出警告。
警告:tensorflow:过去5次调用中有5次调用触发了tf.function回溯。跟踪非常昂贵,过多的跟踪可能是由于传递python对象而不是张量造成的。此外,tf.function还有experimental_relax_shapes=True选项,它可以放松参数形状,避免不必要的回溯。有关更多细节,请参阅https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args和https://www.tensorflow.org/api_docs/python/tf/function。
有更好的方法吗?
发布于 2019-11-21 11:24:57
tf.function有一些“特点”。我强烈建议阅读本文:https://www.tensorflow.org/tutorials/customization/performance
在这种情况下,问题是每次使用不同的输入签名调用时,函数都会被“回溯”(即生成一个新的图)。对于张量,输入签名指的是形状和dtype,但是对于Python数字,每个新值都被解释为“不同的”。在这种情况下,因为使用每次更改的step变量调用函数,所以函数每次都会被回溯。对于“真实”代码(例如,在函数中调用模型),这将是非常慢的。
您可以通过简单地将step转换为张量来修复它,在这种情况下,不同的值将不被算作新的输入签名:
for step in range(100):
step = tf.convert_to_tensor(step, dtype=tf.int64)
my_func(step)
writer.flush()或者使用tf.range直接获得张量:
for step in tf.range(100):
step = tf.cast(step, tf.int64)
my_func(step)
writer.flush()这不应该产生警告(而且速度更快)。
发布于 2022-11-30 00:27:03
尝尝这个
我用的是model(x)而不是model.predict(x),它对我有用。
如果您在自定义函数中获得此错误,则为您的函数添加一个固定的shape和dtype签名。
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
...https://stackoverflow.com/questions/58972225
复制相似问题