在JAX的Quickstart教程中,我发现可以使用以下代码行有效地计算可微函数fun的Hessian矩阵:
from jax import jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))但是,也可以通过计算以下内容来计算Hessian:
def hessian(fun):
return jit(jacrev(jacfwd(fun)))
def hessian(fun):
return jit(jacfwd(jacfwd(fun)))
def hessian(fun):
return jit(jacrev(jacrev(fun)))下面是一个最低限度的工作示例:
import jax.numpy as jnp
from jax import jit
from jax import jacfwd, jacrev
def comp_hessian():
x = jnp.arange(1.0, 4.0)
def sum_logistics(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
def hessian_1(fun):
return jit(jacfwd(jacrev(fun)))
def hessian_2(fun):
return jit(jacrev(jacfwd(fun)))
def hessian_3(fun):
return jit(jacrev(jacrev(fun)))
def hessian_4(fun):
return jit(jacfwd(jacfwd(fun)))
hessian_fn = hessian_1(sum_logistics)
print(hessian_fn(x))
hessian_fn = hessian_2(sum_logistics)
print(hessian_fn(x))
hessian_fn = hessian_3(sum_logistics)
print(hessian_fn(x))
hessian_fn = hessian_4(sum_logistics)
print(hessian_fn(x))
def main():
comp_hessian()
if __name__ == "__main__":
main()我想知道哪一种方法最适合使用,什么时候使用?我还想知道是否可以使用grad()计算Hessian?grad()与jacfwd和jacrev有何不同?
发布于 2022-01-04 13:50:42
您的问题的答案在JAX文档中;例如,请参阅本节:https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev
引用关于jacrev和jacfwd的讨论
这两个函数计算相同的值(直到机器数字),但它们的实现不同:
jacfwd使用前向模式自动微分,这对于“高”雅可比矩阵更有效,而jacrev使用反向模式,这对于“宽”雅可比矩阵更有效。对于接近正方形的矩阵,jacfwd可能比jacrev具有优势.
再往下走,
为了实现
,我们可以使用
jacfwd(jacrev(f))或jacrev(jacfwd(f))或两者的任何其他组合。但正反向通常是最有效的。这是因为在内部雅可比计算中,我们通常是把一个函数区分为广义雅可比(也许就像一个损失函数:ℝⁿ→ℝ),而在外部雅可比计算中,我们用平方雅可比(自∇:ℝⁿ→ℝⁿ)来微分一个函数,这就是前向模式获胜的地方。
因为您的函数看起来像:ℝⁿ→ℝ,那么jit(jacfwd(jacrev(fun)))可能是最有效的方法。
至于为什么不能用grad实现一个恒心函数,这是因为grad只是为具有标量输出的函数的导数而设计的。顾名思义,恒河是向量值雅克比的组合,而不是标量梯度的组合。
https://stackoverflow.com/questions/70572362
复制相似问题