首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >朱莉娅:合子。@来自Enzyme.autodiff。

朱莉娅:合子。@来自Enzyme.autodiff。
EN

Stack Overflow用户
提问于 2022-02-14 15:19:32
回答 1查看 203关注 0票数 3

给定下面的函数f!

代码语言:javascript
复制
function f!(s::Vector, a::Vector, b::Vector)
  
  s .= a .+ b
  return nothing

end # f!

如何定义Zygote的伴随项?

Enzyme.autodiff(f!, Const, Duplicated(s, dz_ds). Duplicated(a, zero(a)), Duplicated(b, zero(b)))

代码语言:javascript
复制
Zygote.@adjoint f!(s, a, b) = f!(s, a, b), # What would come here ?
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-02-15 10:30:04

能想出办法,在这里分享。

对于给定的函数fooZygote.pullback(foo, args...)返回foo(args...)和反向传递(这允许梯度计算)。

我的目标是告诉Zygote使用Enzyme进行反向传球。

这可以通过Zygote.@adjoint来完成(请参阅更多的这里)。

对于数组值函数,Enzyme要求返回nothing及其结果的变异版本在args中(请参阅更多的这里)。

问题帖子中的函数f!是两个数组之和的Enzyme-compatible版本。

因为f!返回nothing,所以当对传递给我们的一些梯度调用向后传递时,Zygote只返回nothing

一种解决方案是将f!放置在返回数组s的包装器(例如f)中。

并将Zygote.@adjoint定义为f,而不是f!

因此,

代码语言:javascript
复制
function f(a::Vector, b::Vector)

  s = zero(a)
  f!(s, a, b)
  return s

end
代码语言:javascript
复制
function enzyme_back(dzds, a, b)

  s    = zero(a)
  dzda = zero(dzds)
  dzdb = zero(dzds)
  Enzyme.autodiff(
    f!,
    Const,
    Duplicated(s, dzds),
    Duplicated(a, dzda),
    Duplicated(b, dzdb)
  )
  return (dzda, dzdb)

end

代码语言:javascript
复制
Zygote.@adjoint f(a, b) = f(a, b), dzds -> enzyme_back(dzds, a, b)

通知Zygote在向后传递中使用Enzyme

最后,您可以检查是否调用Zygote.gradient

代码语言:javascript
复制
g1(a::Vector, b::Vector) = sum(abs2, a + b)

代码语言:javascript
复制
g2(a::Vector, b::Vector) = sum(abs2, f(a, b))

产生同样的结果。

票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71114131

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档