首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >数值稳定softmax

数值稳定softmax
EN

Stack Overflow用户
提问于 2017-03-04 18:11:08
回答 4查看 43.1K关注 0票数 35

下面有一种数值稳定的方法来计算softmax函数吗?我得到的值在神经网络代码中变成了Nans。

代码语言:javascript
复制
np.exp(x)/np.sum(np.exp(y))
EN

回答 4

Stack Overflow用户

回答已采纳

发布于 2017-03-05 09:30:50

softmax exp(x)/sum(exp(x))实际上在数字上表现良好.它只有正项,所以我们不必担心重要性的丧失,分母至少和分子一样大,所以结果保证在0到1之间。

唯一可能发生的意外是指数中的过流或流下。x的所有元素的单一或下流溢出会使输出或多或少变得无用。

但是,很容易通过使用对任何标量c都成立的恒等式softmax(x) = softmax(x + c)来防范这种情况:从x中减去max(x)会留下一个只有非正项的向量,排除溢出,并且至少有一个元素不排除消失的分母(某些条目中的潜流是无害的,但并非所有条目都是无害的)。

注脚:理论上,灾难性事故的总和是可能的,但你需要一个荒谬的数目的条件。例如,即使使用16位浮点数,它只能解析3个小数--相对于“正常”64位浮点数的15个小数--我们也需要在2^1431 (~6×10^431)和2^1432之间得到一个和,即减2倍

票数 73
EN

Stack Overflow用户

发布于 2018-03-10 18:39:38

软件最大函数容易出现两个问题:溢出下溢

溢出:当非常大的数字近似为infinity时发生

底流:当非常小的数字(在数字线上接近于零)被近似(即四舍五入)为zero时发生的。

为了解决这些问题,在进行softmax计算时,一个常见的技巧是从所有元素中减去输入向量中的最大元素来移动输入向量。对于输入向量x,定义z以便:

代码语言:javascript
复制
z = x-max(x)

然后取新的(稳定)向量z的最大软件

示例:

代码语言:javascript
复制
def stable_softmax(x):
    z = x - max(x)
    numerator = np.exp(z)
    denominator = np.sum(numerator)
    softmax = numerator/denominator

    return softmax

# input vector
In [267]: vec = np.array([1, 2, 3, 4, 5])
In [268]: stable_softmax(vec)
Out[268]: array([ 0.01165623,  0.03168492,  0.08612854,  0.23412166,  0.63640865])

# input vector with really large number, prone to overflow issue
In [269]: vec = np.array([12345, 67890, 99999999])
In [270]: stable_softmax(vec)
Out[270]: array([ 0.,  0.,  1.])

在上述情况下,我们使用stable_softmax()安全地避免了溢出问题。

有关更多细节,请参见数值计算书中的章节深度学习

票数 44
EN

Stack Overflow用户

发布于 2019-11-29 23:27:53

扩展@kmario23 23的答案,以支持1或2维的numpy数组或列表。如果您要通过softmax传递一批结果,那么2D张量(假设第一个维度是批处理维度)是常见的:

代码语言:javascript
复制
import numpy as np


def stable_softmax(x):
    z = x - np.max(x, axis=-1, keepdims=True)
    numerator = np.exp(z)
    denominator = np.sum(numerator, axis=-1, keepdims=True)
    softmax = numerator / denominator
    return softmax


test1 = np.array([12345, 67890, 99999999])  # 1D numpy
test2 = np.array([[12345, 67890, 99999999], # 2D numpy
                  [123, 678, 88888888]])    #
test3 = [12345, 67890, 999999999]           # 1D list
test4 = [[12345, 67890, 999999999]]         # 2D list

print(stable_softmax(test1))
print(stable_softmax(test2))
print(stable_softmax(test3))
print(stable_softmax(test4))

 [0. 0. 1.]

[[0. 0. 1.]
 [0. 0. 1.]]

 [0. 0. 1.]

[[0. 0. 1.]]
票数 9
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/42599498

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档