首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >“深度强化学习:动手”中的Pytorch问题

“深度强化学习:动手”中的Pytorch问题
EN

Stack Overflow用户
提问于 2019-10-02 07:48:55
回答 2查看 56关注 0票数 2

我正在读Maxim Lapan的“深度学习实践”。我在第二章中遇到了这段代码,有几件事我不明白。谁能解释一下为什么print(out)的输出给出了三个参数,而不是我们输入的单个浮点张量。另外,为什么这里需要超级函数呢?最后,forward接受的x参数是什么?谢谢。

代码语言:javascript
复制
class OurModule(nn.Module):
    def __init__(self, num_inputs, num_classes, dropout_prob=0.3):  #init 
        super(OurModule, self).__init__() #Call OurModule and pass the net instance (Why is this necessary?) 
        self.pipe = nn.Sequential( #net.pipe is the nn object now
            nn.Linear(num_inputs, 5),
            nn.ReLU(),
            nn.Linear(5, 20),
            nn.ReLU(),
            nn.Linear(20, num_classes),
            nn.Dropout(p=dropout_prob),
            nn.Softmax(dim=1)
        )

    def forward(self, x): #override the default forward method by passing it our net instance and (return the nn object?). x is the tensor? This is called when 'net' receives a param?
        return self.pipe(x)

if __name__ == "__main__":
    net = OurModule(num_inputs=2, num_classes=3)
    print(net)
    v = torch.FloatTensor([[2, 3]])
    out = net(v)
    print(out) #[2,3] put through the forward method of the nn? Why did we get a third param for the output?
    print("Cuda's availability is %s" % torch.cuda.is_available()) #find if gpu is available
    if torch.cuda.is_available():
        print("Data from cuda: %s" % out.to('cuda'))

OurModule.__mro__
EN

回答 2

Stack Overflow用户

发布于 2019-10-02 08:50:19

OurModule定义了一个PyTorch nn.Module,它接受2输入(num_inputs)并产生3输出(num_classes)。

它由以下部分组成:

  1. 接受2输入并产生5输出
  2. 的层接受5输入并生成D18的ReLU
    1. A Linear层输出H2195>H120接受D25输入并产生D26(D27)的ReLU
    2. A D24层输出H228H129层Dropout

    A Softmax

您将创建由2输入组成的v,并在调用net(v)时通过此网络的forward()方法传递它。然后,将运行此网络的结果(3输出)存储在out中。

在您的示例中,x采用vtorch.FloatTensor([[2, 3]])的值

票数 3
EN

Stack Overflow用户

发布于 2019-10-02 17:28:10

虽然@JoshVarty已经提供了一个很好的答案,但我想补充一点。

为什么这里需要超级函数

OurModule继承了nn.Module。超级函数意味着您想要使用父函数(nn.Module)的函数,即init。您可以参考source code来了解父类在init函数中到底做了什么。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58193626

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档