首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >扩展PyTorch nn.Sequential类

扩展PyTorch nn.Sequential类
EN

Stack Overflow用户
提问于 2020-06-18 05:46:45
回答 1查看 397关注 0票数 0

我对Python中的面向对象编程非常陌生,而且一般都是生疏的。我想以这样一种方式扩展PyTorch的'nn.Sequential‘对象,传递给它一个包含每一层中的节点数的元组,根据这些节点自动生成一个OrderedDict。下面是一个功能示例:

代码语言:javascript
复制
layers = (784, 392, 196, 98, 10)
n_layers = len(layers)
modules = OrderedDict()

# Layer definitions for inner layers:
for i in range(n_layers - 2):
    modules[f'fc{i}']   = nn.Linear(layers[i], layers[i+1])
    modules[f'relu{i}'] = nn.ReLU()

# Definition for output layer:
modules['fc_out'] = nn.Linear(layers[-2], layers[-1])
modules['smax_out'] = nn.LogSoftmax(dim=1)

# Define model and check attributes:
model = nn.Sequential(modules)

因此,我不想在初始化nn.Sequential时传递'OrderedDict‘对象,而是希望我的类接受元组。

代码语言:javascript
复制
class Network(nn.Sequential):
   def__init__(self, n_nodes):
      super().__init__()

      **** INSERT LOGIC FROM LAST SNIPPET ***

所以这似乎行不通,因为当我的Network类调用super().__init__()时,它会想要查看层激活的字典。我该如何着手编写我自己的网络,让它绕过这个问题,同时仍然拥有PyTorche的顺序对象的所有功能?

我是这样想的:

代码语言:javascript
复制
class Network(nn.Sequential):
    def __init__(self, layers):
        super().__init__(self.init_modules(layers))


    def init_modules(self, layers):
        n_layers = len(layers)
        modules = OrderedDict()

        # Layer definitions for inner layers:
        for i in range(n_layers - 2):
            modules[f'fc{i}']   = nn.Linear(layers[i], layers[i+1])
            modules[f'relu{i}'] = nn.ReLU()

        # Definition for output layer:
        modules['fc_out'] = nn.Linear(layers[-2], layers[-1])
        modules['smax_out'] = nn.LogSoftmax(dim=1)

        return modules

我不确定这种事情在Python中是否允许和/或良好的实践。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-18 20:26:05

你的实现是允许的,而且是好的。

而且,您还可以初始化super().__init__() a,然后在循环中使用self.add_module(key, module)附加LinearRelu或其他任何后续内容。以这种方式,函数__init__可以覆盖init_modules的工作。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62438892

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档