给定下面的函数f!:
function f!(s::Vector, a::Vector, b::Vector)
s .= a .+ b
return nothing
end # f!如何定义Zygote的伴随项?
Enzyme.autodiff(f!, Const, Duplicated(s, dz_ds). Duplicated(a, zero(a)), Duplicated(b, zero(b)))?
Zygote.@adjoint f!(s, a, b) = f!(s, a, b), # What would come here ?发布于 2022-02-15 10:30:04
能想出办法,在这里分享。
对于给定的函数foo,Zygote.pullback(foo, args...)返回foo(args...)和反向传递(这允许梯度计算)。
我的目标是告诉Zygote使用Enzyme进行反向传球。
这可以通过Zygote.@adjoint来完成(请参阅更多的这里)。
对于数组值函数,Enzyme要求返回nothing及其结果的变异版本在args中(请参阅更多的这里)。
问题帖子中的函数f!是两个数组之和的Enzyme-compatible版本。
因为f!返回nothing,所以当对传递给我们的一些梯度调用向后传递时,Zygote只返回nothing。
一种解决方案是将f!放置在返回数组s的包装器(例如f)中。
并将Zygote.@adjoint定义为f,而不是f!。
因此,
function f(a::Vector, b::Vector)
s = zero(a)
f!(s, a, b)
return s
endfunction enzyme_back(dzds, a, b)
s = zero(a)
dzda = zero(dzds)
dzdb = zero(dzds)
Enzyme.autodiff(
f!,
Const,
Duplicated(s, dzds),
Duplicated(a, dzda),
Duplicated(b, dzdb)
)
return (dzda, dzdb)
end和
Zygote.@adjoint f(a, b) = f(a, b), dzds -> enzyme_back(dzds, a, b)通知Zygote在向后传递中使用Enzyme。
最后,您可以检查是否调用Zygote.gradient
g1(a::Vector, b::Vector) = sum(abs2, a + b)或
g2(a::Vector, b::Vector) = sum(abs2, f(a, b))产生同样的结果。
https://stackoverflow.com/questions/71114131
复制相似问题