我试图在TensorFlow 2中训练一个LeNet-5模型,同时用C语言编写的自定义矩阵乘法和Conv2D卷积代替所有的稠密矩阵乘法和Conv2D卷积,更准确地说,我想保留这些梯度,因为它们是默认的,但是使用我实现的这些操作,而不是TensorFlow的默认操作。而且,我不能使用TensorFlow执行自定义卷积和矩阵乘法,我必须遍历C代码,这是通过CTypes调用的。有没有办法这样做呢?
到目前为止,我尝试的是使用TensorFlow的@tf.experimental.dispatch_for_api来调用使用tf.py_function的函数,后者反过来调用我的C代码。然而,似乎在这样做时,梯度会丢失,模型无法被训练。还有别的办法吗?
@tf.experimental.dispatch_for_api(tf.matmul,
{'a': ApproximateTensor},
{'b': ApproximateTensor},
{'a': tf.Tensor , 'b': tf.Tensor },
{'a': ApproximateTensor, 'b': tf.Tensor },
{'a': tf.Tensor , 'b': ApproximateTensor},
{'a': ApproximateTensor, 'b': ApproximateTensor},
)
def custom_matmul(a, b, transpose_a=False, transpose_b=False, adjoint_a=False, adjoint_b=False, a_is_sparse=False, b_is_sparse=False, output_type=None):
# tf.print('MATMUL')
if not isinstance(a, ApproximateTensor):
a = ApproximateTensor(a)
if not isinstance(b, ApproximateTensor):
b = ApproximateTensor(b)
_, out_size = b.shape
return ApproximateTensor(process.linear(a.values, b.values, tf.zeros(out_size)))然后,process.linear会这样做:
def linear(input, kernel, bias):
global c_approx
def compute(input, kernel, bias):
global c_approx
# Extract dimensions
batch_size, in_size = input.shape
_, out_size = kernel.shape
if batch_size is None:
batch_size = 1
# Create output
output = np.zeros((batch_size, out_size), dtype=np.float32)
# Compute
for b in range(batch_size):
output[b] = c_approx.custom_matmul(input[b], kernel[i]) + bias
return tf.convert_to_tensor(output)
output = tf.py_function(compute, [input, kernel, bias], input.dtype)
# Manually set output dimensions
batch_size, _ = input.shape
_, out_size = kernel.shape
output.set_shape((batch_size, out_size))
return output换句话说,我想要的是与以下代码相反的代码:
@tf.custom_gradient
def custom_conv(x):
def grad_fn(dy):
return dy
return tf.nn.conv2d(x), grad_fn我想重新定义Conv2D,同时保持他的默认梯度。
提前感谢
发布于 2022-11-02 12:37:35
我真的想出了答案,多亏了这个答案:https://stackoverflow.com/a/43952168/9675442
这个想法就是这样做的:
y = tf.matmul(a, b) + tf.stop_gradient(compute(a, b) - tf.matmul(a, b))我希望这会对其他人有所帮助
https://stackoverflow.com/questions/74285505
复制相似问题