
Jax autograd of a sigmoid always returns nan

Stack Overflow user
Asked on 2021-07-08 01:45:23
Answers: 1 · Views: 144 · Followers: 0 · Votes: 1

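I'm trying to differentiate a function that approximates the fraction of a Gaussian that lies within two limits (a truncated Gaussian), given a shifted mean. jax.grad won't let me differentiate through the summed boolean filter (the commented-out line), so I had to improvise with a sigmoid.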
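To see why a soft filter is needed at all: a comparison such as `y > l` is piecewise constant in `y`, so even where JAX can differentiate through it, the gradient is zero almost everywhere and carries no information about the shift. A minimal sketch contrasting the two (`hard_indicator` and `soft_indicator` are illustrative names, not from the question; the soft version uses the same sigmoid form as the question's code):

```python
import jax
import jax.numpy as jnp

def hard_indicator(x, lower=0.0):
    # Boolean comparison: piecewise constant, so its derivative is 0 almost everywhere
    return jnp.where(x > lower, 1.0, 0.0)

def soft_indicator(x, lower=0.0, scale=100.0):
    # Sigmoid relaxation of the step x > lower; a larger scale gives a sharper step
    return 1 / (1 + jnp.exp(-(x - lower) * scale))

x = 0.01
print(jax.grad(hard_indicator)(x))  # 0.0: no gradient signal through the comparison
print(jax.grad(soft_indicator)(x))  # positive near the boundary
```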

However, now the gradient is always nan when the truncation bounds are high, and I don't understand why.

In the example below, I compute the gradient for a Gaussian with mean 0 and std 1, and then shift it by x.

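If I reduce the bounds, the function behaves as expected, but that isn't a solution. When the bounds are high, belows always becomes 1. But if that's the case, x has no effect on belows, so its contribution to the gradient should be 0, not nan. Yet even if I return belows[0][0] instead of jnp.mean(filt, axis=0), I still get nan.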

Any ideas? Thanks in advance (there is also an open issue on GitHub).

Code language: Python
import os

from tqdm import tqdm

os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' # Use 4 CPU devices
import numpy as np
import jax
jax.config.update("jax_enable_x64", True)  # use float64 throughout
import jax.numpy as jnp
from jax import vmap

from functools import reduce

def sigmoid(x, scale=100):
    return 1 / (1 + jnp.exp(-x*scale))

def above_lower(x, l, scale=100):
    return sigmoid(x - l, scale)

def below_upper(x, u, scale=100):
    return 1 - sigmoid(x - u, scale)

def combine_soft_filters(a):
    return jnp.prod(jnp.stack(a), axis=0)


def fraction_not_truncated(mu, v, limits, stdnorm_samples):
    L = jnp.linalg.cholesky(v)
    y = vmap(lambda x: jnp.dot(L, x))(stdnorm_samples) + mu
    # filt = reduce(jnp.logical_and, [(y[..., i] > l) & (y[..., i] < u) for i, (l, u) in enumerate(limits)])
    aboves = [above_lower(y[..., i], l) for i, (l, u) in enumerate(limits)]
    belows = [below_upper(y[..., i], u) for i, (l, u) in enumerate(limits)]
    filt = combine_soft_filters(aboves+belows)
    return jnp.mean(filt, axis=0)

limits = np.array([
        [0.,1000],
])

stdnorm_samples = np.random.multivariate_normal([0], np.eye(1), size=1000)

def func(x):
    return fraction_not_truncated(jnp.zeros(1)+x, jnp.eye(1), limits, stdnorm_samples)

_x = np.linspace(-2, 2, 500)
gradfunc = jax.grad(func)
vals = [func(x) for x in tqdm(_x)]
grads = [gradfunc(x) for x in tqdm(_x)]
print(vals)
print(grads)
import matplotlib.pyplot as plt
plt.plot(_x, np.asarray(vals))
plt.ylabel('f(x)')
plt.twinx()
plt.plot(_x, np.asarray(grads), c='r')
plt.ylabel("f(x)'")
plt.title('Fraction not truncated')
plt.axhline(0, color='k', alpha=0.2)
plt.xlabel('shift')
plt.tight_layout()
plt.show()

Output (values, then gradients):
[DeviceArray(1., dtype=float64), DeviceArray(1., dtype=float64), ..., DeviceArray(1., dtype=float64)]
[DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), ..., DeviceArray(nan, dtype=float64)]
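For the one-dimensional setup in the script (identity covariance, limits [0, 1000], mean shifted by x), the quantity that fraction_not_truncated estimates by Monte Carlo also has a closed form in terms of the normal CDF, which gives a finite reference gradient to compare against. A minimal sketch, assuming exactly that 1-D, unit-variance case; `exact_fraction` is a hypothetical helper, not part of the question's code:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def exact_fraction(x, lower=0.0, upper=1000.0):
    # Hypothetical helper: P(lower < Y < upper) for Y ~ N(x, 1) in one dimension
    return norm.cdf(upper - x) - norm.cdf(lower - x)

grad_exact = jax.grad(exact_fraction)
for x in (-2.0, 0.0, 2.0):
    # d/dx = pdf(lower - x) - pdf(upper - x); with lower=0 and a far upper bound this is about pdf(x)
    print(x, exact_fraction(x), grad_exact(x))
```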

1 Answer

Stack Overflow user

Accepted answer

Posted on 2021-07-08 07:21:05

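The problem is that your sigmoid function is implemented in such a way that the automatically computed gradient is numerically unstable for large negative values of x: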

Code language: Python
import jax.numpy as jnp
import jax

def sigmoid(x, scale=100):
    return 1 / (1 + jnp.exp(-x*scale))

print(jax.grad(sigmoid)(-1000.0))
# nan

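You can see why this happens by using the jax.make_jaxpr function to inspect the operations generated by the autodiff transformation (the comments are my annotations):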

Code language: Python
>>> jax.make_jaxpr(jax.grad(sigmoid))(-1000.0)
{ lambda  ; a.                    # a = -1000
  let b = neg a                   # b = 1000
      c = mul b 100.0             # c = 100,000
      d = exp c                   # d = inf
      e = add d 1.0
      _ = div 1.0 e
      f = integer_pow[ y=-2 ] e   # f = 0
      g = mul 1.0 f               # g = 0
      h = mul g 1.0               # h = 0
      i = neg h                   # i = 0
      j = mul i d                 # j = 0 * inf = NaN
      k = mul j 100.0             # k = NaN
      l = neg k                   # l = NaN
  in (l,) }                       # return NaN

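This is one of the cases where 64-bit floating-point arithmetic breaks down: it cannot represent a number as large as exp(100000). The largest finite float64 is about 1.8e308, so exp overflows to inf once its argument exceeds roughly 709.78, and the chain rule then multiplies that inf by 0, which yields NaN.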
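A quick check of the overflow threshold and of the 0 * inf step, assuming float64 as in the question's script (the answer's own snippet runs in JAX's default float32, where exp overflows even sooner):

```python
import jax
jax.config.update("jax_enable_x64", True)  # match the question's float64 setting
import jax.numpy as jnp

# float64 overflows once the exponent passes log(max float64), about 709.78
print(jnp.exp(709.0))     # finite, about 8.2e307
print(jnp.exp(710.0))     # inf
print(jnp.exp(100000.0))  # inf: the value d in the jaxpr above
# The backward pass multiplies a zero by this infinity, which IEEE 754 defines as NaN
print(0.0 * jnp.exp(100000.0))  # nan
```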

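So what can you do? One heavier-weight option is to use a custom derivative rule to tell autodiff how to handle the sigmoid in a numerically stable way. In this case, though, a simpler option is to re-express the sigmoid in a form that behaves better under autodiff transformations. One option is: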

Code language: Python
def sigmoid(x, scale=100):
    return 0.5 * (jnp.tanh(x * scale / 2) + 1)
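A quick numerical check of this rewrite, as a sketch: it compares the original exp-based form with the tanh form (the names `sigmoid_exp` and `sigmoid_tanh` are illustrative) on both values and gradients. The tanh form stays finite because the derivative of tanh is 1 - tanh(z)^2, which JAX evaluates from the saturated output (exactly -1 here), giving 0 rather than 0 * inf.

```python
import jax
import jax.numpy as jnp

def sigmoid_exp(x, scale=100):
    # The question's original form: its gradient evaluates 0 * inf for large negative x
    return 1 / (1 + jnp.exp(-x * scale))

def sigmoid_tanh(x, scale=100):
    # Rewritten via the identity sigmoid(z) = (tanh(z / 2) + 1) / 2
    return 0.5 * (jnp.tanh(x * scale / 2) + 1)

xs = jnp.linspace(-0.05, 0.05, 11)
# Forward values agree up to float rounding in the non-saturated region
print(jnp.max(jnp.abs(sigmoid_exp(xs) - sigmoid_tanh(xs))))
print(jax.grad(sigmoid_exp)(-1000.0))   # nan
print(jax.grad(sigmoid_tanh)(-1000.0))  # zero instead of nan
print(jax.grad(sigmoid_tanh)(0.0))      # 25.0, the same slope as the exp form at 0
```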

Using this version in your script fixes the problem.
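The custom derivative rule option mentioned above can be sketched with jax.custom_jvp. This is a minimal sketch under stated assumptions, not the answer's code: the sharpness is fixed at 100 (the question's default) and held in a module constant `SCALE` instead of being passed as an argument, and `stable_sigmoid` is an illustrative name. The rule uses sigmoid'(z) = sigmoid(z)(1 - sigmoid(z)), which is built only from the bounded output and so never forms 0 * inf.

```python
import jax
import jax.numpy as jnp

SCALE = 100.0  # assumption: fixed sharpness, matching the question's default scale

@jax.custom_jvp
def stable_sigmoid(x):
    # Same forward pass as the question's sigmoid; exp may overflow to inf, which is harmless here
    return 1 / (1 + jnp.exp(-x * SCALE))

@stable_sigmoid.defjvp
def stable_sigmoid_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    s = stable_sigmoid(x)
    # d/dx sigmoid(SCALE * x) = SCALE * s * (1 - s): bounded for every x
    return s, SCALE * s * (1 - s) * x_dot

print(jax.grad(stable_sigmoid)(-1000.0))  # 0.0 instead of nan
print(jax.grad(stable_sigmoid)(0.0))      # 25.0
```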

Votes: 1
Original page content provided by Stack Overflow. Translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/68290850
