首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >JAX中Hessian矩阵的高效计算

JAX中Hessian矩阵的高效计算
EN

Stack Overflow用户
提问于 2022-01-03 23:01:55
回答 1查看 1K关注 0票数 3

在JAX的Quickstart教程中,我发现可以使用以下代码行有效地计算可微函数fun的Hessian矩阵:

代码语言:javascript
复制
from jax import jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

但是,也可以通过计算以下内容来计算Hessian:

代码语言:javascript
复制
def hessian(fun):
  return jit(jacrev(jacfwd(fun)))

def hessian(fun):
  return jit(jacfwd(jacfwd(fun)))

def hessian(fun):
  return jit(jacrev(jacrev(fun)))

下面是一个最低限度的工作示例:

代码语言:javascript
复制
import jax.numpy as jnp
from jax import jit
from jax import jacfwd, jacrev

def comp_hessian():

    x = jnp.arange(1.0, 4.0)

    def sum_logistics(x):
        return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

    def hessian_1(fun):
        return jit(jacfwd(jacrev(fun)))

    def hessian_2(fun):
        return jit(jacrev(jacfwd(fun)))

    def hessian_3(fun):
        return jit(jacrev(jacrev(fun)))

    def hessian_4(fun):
        return jit(jacfwd(jacfwd(fun)))

    hessian_fn = hessian_1(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_2(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_3(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_4(sum_logistics)
    print(hessian_fn(x))


def main():
    comp_hessian()


if __name__ == "__main__":
    main()

我想知道哪一种方法最适合使用,什么时候使用?我还想知道是否可以使用grad()计算Hessian?grad()jacfwdjacrev有何不同?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-01-04 13:50:42

您的问题的答案在JAX文档中;例如,请参阅本节:https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev

引用关于jacrevjacfwd的讨论

这两个函数计算相同的值(直到机器数字),但它们的实现不同:jacfwd使用前向模式自动微分,这对于“高”雅可比矩阵更有效,而jacrev使用反向模式,这对于“宽”雅可比矩阵更有效。对于接近正方形的矩阵,jacfwd可能比jacrev具有优势.

再往下走,

为了实现

,我们可以使用jacfwd(jacrev(f))jacrev(jacfwd(f))或两者的任何其他组合。但正反向通常是最有效的。这是因为在内部雅可比计算中,我们通常是把一个函数区分为广义雅可比(也许就像一个损失函数:ℝⁿ→ℝ),而在外部雅可比计算中,我们用平方雅可比(自∇:ℝⁿ→ℝⁿ)来微分一个函数,这就是前向模式获胜的地方。

因为您的函数看起来像:ℝⁿ→ℝ,那么jit(jacfwd(jacrev(fun)))可能是最有效的方法。

至于为什么不能用grad实现一个恒心函数,这是因为grad只是为具有标量输出的函数的导数而设计的。顾名思义,恒河是向量值雅克比的组合,而不是标量梯度的组合。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70572362

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档