我有一个定制的tf.keras.layers.Layer,它只使用TF运算符进行某种位解包装(将整数转换为布尔值(0或1浮点数))。
class CharUnpack(keras.layers.Layer):
def __init__(self, name="CharUnpack", *args, **kwargs):
super(CharUnpack, self).__init__(trainable=False, name=name, *args, **kwargs)
# Range [7, 6, ..., 0] to bit-shift integers
self._shifting_range = tf.reshape(
tf.dtypes.cast(
tf.range(7, -1, -1, name='shifter_range'),
tf.uint8,
name='shifter_cast'),
(1, 1, 8),
name='shifter_reshape')
# Constant value 0b00000001 to use as bitwise and operator
self._selection_bit = tf.constant(0x01, dtype=tf.uint8, name='and_selection_bit')
def call(self, inputs):
return tf.dtypes.cast(
tf.reshape(
tf.bitwise.bitwise_and(
tf.bitwise.right_shift(
tf.expand_dims(inputs, 2),
self._shifting_range,
),
self._selection_bit,
),
[x if x else -1 for x in self.compute_output_shape(inputs.shape)]
),
tf.float32
)
def compute_output_shape(self, input_shape):
try:
if len(input_shape) > 1:
output_shape = tf.TensorShape(tuple(list(input_shape[:-1]) + [input_shape[-1] * 8]))
else:
output_shape = tf.TensorShape((input_shape[0] * 8,))
except TypeError:
output_shape = input_shape
return output_shape
def compute_output_signature(self, input_signature):
return tf.TensorSpec(self.compute_output_shape(input_signature.shape), tf.float32)我试图对这个层进行基准测试,以提高时间性能,如这个TF导轨所示。
inputs = tf.zeros([64, 400], dtype=tf.uint8)
eager = CharUnpack()
@tf.function
def fun(x):
eager(x)
# Warm-up
eager(inputs)
fun(inputs)
print("Function:", timeit.timeit(lambda: fun(inputs), number=100))
print("Eager:", timeit.timeit(lambda: eager(inputs), number=100))Function: 0.01062483999885444
Eager: 0.12658399900101358如你所见,我可以提速10倍!因此,我将@tf.function装饰器添加到我的CharUnpack.call方法中:
+ @tf.function
def call(self, inputs):
return tf.dtypes.cast(现在,我希望eager和fun都会调用类似的时间,但我没有得到任何改进。
Function: 0.009667591999459546
Eager: 0.10346330100037449此外,在本所以回答第2.1节中指出,模型默认是图形编译的(这应该是逻辑的),但情况似乎并非如此.
如何正确地使用@tf.function装饰器使我的层始终是图形编译的?
发布于 2022-08-14 11:37:46
tldr:fun()不返回任何东西,tensorflow签名非常聪明,可以实现这一点,并且忽略了在fun()中发生的所有TF计算,而eager(x) 必须执行call()函数中正在发生的事情。这就是为什么fun()的执行时间少得离谱。至少这是我认为正在发生的事情--我不是AutoGraph专家,所以如果我出了什么问题,其他人可能会纠正我。
问题投资
在我们投入之前,让我们把事情简化一下。首先,我对您的原始代码进行了如下修改。让我们增加数据的大小,以确保所涉及的数据处理数量足够多,并且数据传输和其他开销并不是分析的主导因素。
inputs = tf.zeros([8192, 400], dtype=tf.uint8)其次,我去掉了一些计算,例如compute_output_shape(),并固定了形状。还在call()中引入了一些张量定义。因此call()负责计算端到端的变量定义。
class CharUnpack(tf.keras.layers.Layer):
def __init__(self, name="CharUnpack", *args, **kwargs):
super(CharUnpack, self).__init__(trainable=False, name=name, *args, **kwargs)
self._shifting_range = None
self._selection_bit = None
@tf.function
def call(self, inputs):
if not self._shifting_range:
# Range [7, 6, ..., 0] to bit-shift integers
self._shifting_range = tf.reshape(
tf.dtypes.cast(
tf.range(7, -1, -1, name='shifter_range'),
tf.uint8,
name='shifter_cast'
),
(1, 1, 8),
name='shifter_reshape')
if not self._selection_bit:
# Constant value 0b00000001 to use as bitwise and operator
self._selection_bit = tf.constant(0x01, dtype=tf.uint8, name='and_selection_bit')
return tf.dtypes.cast(
tf.reshape(
tf.bitwise.bitwise_and(
tf.bitwise.right_shift(
tf.expand_dims(inputs, 2),
self._shifting_range,
),
self._selection_bit,
),
[x if x else -1 for x in self.compute_output_shape(inputs.shape)]
),
tf.float32
)
def compute_output_shape(self, input_shape):
return [8192, 3200]第三,我在time at操作中设置了number=1,以确保每次只分析一个调用。这使我们更容易理解。
# The very first call of either approach
print("Eager:", timeit.timeit(lambda: eager(inputs), number=1))
print("Function:", timeit.timeit(lambda: fun(inputs), number=1))
# The second call
print("Eager:", timeit.timeit(lambda: eager(inputs), number=1))
print("Function:", timeit.timeit(lambda: fun(inputs), number=1))首先,让我们看一下eager()的具体功能
eager_concrete = eager.call.get_concrete_function(tf.TensorSpec(shape=[None, 400], dtype=tf.uint8))
print(eager_concrete)这给了,
ConcreteFunction call(inputs)
Args:
inputs: uint8 Tensor, shape=(None, 400)
Returns:
float32 Tensor, shape=(8192, 3200)让我们看一下fun()的具体功能
fun_concrete = fun.get_concrete_function(tf.TensorSpec(shape=[None, 400], dtype=tf.uint8))
print(fun_concrete)这给了,
ConcreteFunction fun(x)
Args:
x: uint8 Tensor, shape=(None, 400)
Returns:
NoneTensorSpec()因此,您马上就会发现,fun()没有返回任何东西,这应该会在您的脑海中引起注意。更进一步,我们可以看一看签名的追踪所产生的图形。
graph = fun_concrete.graph
for node in graph.as_graph_def().node:
print(f'{node.input} -> {node.name}')输出,
[] -> x
['x'] -> CharUnpack/StatefulPartitionedCall接下来,如果您对eager()做同样的操作,您将看到下面列出的所有基本TF操作。
[] -> inputs
[] -> StringFormat
['StringFormat'] -> PrintV2
[] -> shifter_range/start
...
['Reshape'] -> Cast
['Cast', '^NoOp'] -> Identity我们甚至可以查看生成的代码。
print(tf.autograph.to_code(fun.python_function))这给了,
def tf__fun(x):
with ag__.FunctionScope('fun', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
out = ag__.converted_call(ag__.ld(eager), (ag__.ld(x),), None, fscope)所以看一下代码,它所做的就是为eager和x生成一个
我不是AutoGraph专家,但我想它所做的只是将给定的输入x传递给eager.call(),并跳过所有的计算。因此,fun()只是跳过了eager.call()函数中所有重要的计算。
我们如何使fun()实际进行计算?
只需将return语句添加到fun()中即可。
@tf.function
def fun(x):
out = eager(x)
return out这给了,
Eager: 0.6245606249999582
Function: 0.3163724480000383
Eager: 0.2076279070001874
Function: 0.22467646699988109
Eager: 0.25076841500003866
Function: 0.240701412999897现在我们可以看到,eager.call()和fun()的时间是相同的。
从TF文件上说,
除了tf.Variables之外,tf.function必须返回其所有输出。
虽然这一节突出了问题的另一面,但它可能(间接)与这里发生的事情有关。
发布于 2022-05-24 10:30:48
首先,tf.function不需要嵌套使用,也就是说,您只能包装您的自定义train_step() (包含传播)。在这种情况下,不需要包装内部层或子模型的call()函数,因为它们涉及到您的train_step。嵌套使用可能会导致某些意外的性能下降。
第二,任何计算加速都要付出代价,tf.function是一种以时间交换空间的方法,需要初始化来构建Graph。因此,在基准测试时,我们应该多次重新运行相同的函数,因为tf.function的辅助调用不会花费构建时间,只要Tracing没有任何改变。
for _ in range(5):
print("Function:", timeit.timeit(lambda: fun(inputs), number=100))
for _ in range(5):
print("Eager:", timeit.timeit(lambda: eager(inputs), number=100))
# Function: 0.02040819999820087
# Function: 0.020311099986429326
# Function: 0.020155799997155555
# Function: 0.02004839999426622
# Function: 0.019999900003313087
# Eager: 0.035980800006655045
# Eager: 0.035652499995194376
# Eager: 0.035596200003055856
# Eager: 0.03490520000923425
# Eager: 0.03762050000659656https://stackoverflow.com/questions/58233805
复制相似问题