首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >torch.nn.Softmax、torch.nn.funtional.softmax、torch.softmax和torch.nn.functional.log_softmax的区别是什么

torch.nn.Softmax、torch.nn.funtional.softmax、torch.softmax和torch.nn.functional.log_softmax的区别是什么
EN

Stack Overflow用户
提问于 2021-09-17 03:08:59
回答 2查看 1.1K关注 0票数 0

我试图查找文档,但找不到任何有关torch.softmax的信息。

torch.nn.Softmax、torch.nn.funtional.softmax、torch.softmax和torch.nn.functional.log_softmax有什么区别?

我们很欣赏这些例子。

EN

回答 2

Stack Overflow用户

发布于 2021-09-17 03:28:09

代码语言:javascript
复制
import torch

x = torch.rand(5)

x1 = torch.nn.Softmax()(x)
x2 = torch.nn.functional.softmax(x)
x3 = torch.nn.functional.log_softmax(x)

print(x1)
print(x2)
print(torch.log(x1))
print(x3)
代码语言:javascript
复制
tensor([0.2740, 0.1955, 0.1519, 0.1758, 0.2029])
tensor([0.2740, 0.1955, 0.1519, 0.1758, 0.2029])
tensor([-1.2946, -1.6323, -1.8847, -1.7386, -1.5952])
tensor([-1.2946, -1.6323, -1.8847, -1.7386, -1.5952])

torch.nn.Softmaxtorch.nn.functional.softmax给出相同的输出,一个是类(pytorch模块),另一个是函数。log_softmax在应用softmax之后应用log。

NLLLoss接受对数概率(log(softmax(X)作为输入。因此,对于NLLLoss,你需要log_softmax,log_softmax在数值上更稳定,通常会产生更好的结果。

票数 2
EN

Stack Overflow用户

发布于 2021-09-17 03:19:59

代码语言:javascript
复制
import torch
import torch.nn as nn


class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.LazyLinear(128)
        self.activation = nn.ReLU()
        self.layer_2 = nn.Linear(128, 10)
        self.output_function = nn.Softmax(dim=1)

    def forward(self, x, softmax="module"):
        y = self.layer_1(x)
        y = self.activation(y)
        y = self.layer_2(y)
        if softmax == "module":
            return self.output_function(y)

        # OR
        if softmax == "torch":
            return torch.softmax(y, dim=1)

        # OR (deprecated)
        if softmax == "functional":
            return nn.functional.softmax(y, dim=1)

        # OR (careful, the reason why the log is there is to ensure
        # numerical stability so you should use torch.exp wisely)
        if softmax == "log":
            return torch.exp(torch.log_softmax(y, dim=1))

        raise ValueError(f"Unknown softmax type {softmax}")


x = torch.rand(2, 2)
net = Network()

for s in ["module", "torch", "log"]:
    print(net(x, softmax=s))

基本上,nn.Softmax()创建了一个模块,所以它返回一个函数,而其他的都是纯函数。

为什么你需要一个?在nn.Softmax的文档中有一个例子

此模块不直接与NLLLoss一起工作,后者希望在Softmax和自身之间计算日志。改为使用LogSoftmax (它更快,并且具有更好的数值属性)。

另请参阅What is the difference between log_softmax and softmax?

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69217305

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档