首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何从nn.module ()传递参数来转发torch nn.module的函数

如何从nn.module ()传递参数来转发torch nn.module的函数
EN

Stack Overflow用户
提问于 2019-03-14 07:16:53
回答 1查看 3.5K关注 0票数 2

我扩展了nn.Module来实现我的网络,它的前向功能是这样的.

代码语言:javascript
复制
def forward(self, X, **kwargs):

    batch_size, seq_len = X.size()

    length = kwargs['length']
    embedded = self.embedding(X) # [batch_size, seq_len, embedding_dim]
    if self.use_padding:
        if length is None:
            raise AttributeError("Length must be a tensor when using padding")
        embedded = nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=True)
        #print("Size of Embedded packed", embedded[0].size())


    hidden, cell = self.init_hidden(batch_size)
    if self.rnn_unit == 'rnn':
        out, _ = self.rnn(embedded, hidden)
    elif self.rnn_unit == 'lstm':
        out, (hidden, cell) = self.rnn(embedded, (hidden, cell))


    # unpack if padding was used
    if self.use_padding:
        out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first = True)

我像这样初始化了一个skorch NeuralNetClassifier

代码语言:javascript
复制
net = NeuralNetClassifier(
    model,
    criterion=nn.CrossEntropyLoss,
    optimizer=Adam, 
    max_epochs=8, 
    lr=0.01, 
    batch_size=32
)

现在如果我调用net.fit(X, y, length=X_len),它会抛出一个错误

代码语言:javascript
复制
TypeError: __call__() got an unexpected keyword argument 'length'

根据documentation函数需要一个fit_params字典,

**fit_params :删除传递给模块的forward方法和self.train\_split调用的附加参数。

而且源代码总是将我的参数发送到train_split,在那里我的关键字参数显然不会被识别。

有没有办法把参数传递给我的前向函数?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-03-27 10:26:49

fit_params参数用于传递与数据拆分和模型相关的信息,就像拆分组一样。

在您的示例中,您要通过fit_params将额外的数据传递给模块,这不是它的目的。事实上,如果您在列车数据加载程序上启用批处理操作,那么您很容易就会遇到这样的麻烦,因为从那时起,您的长度和数据就错对齐了。

最好的方法已经在回答您关于问题跟踪器的问题中描述过了。

代码语言:javascript
复制
X_dict = {'X': X, 'length': X_len}
net.fit(X_dict, y)

由于skorch支持dict,所以您可以简单地将长度添加到输入dict中,并将其传递给模块,很好地批次,并通过相同的数据加载器进行传递。在您的模块中,您可以通过forward中的参数访问它。

代码语言:javascript
复制
def forward(self, X, length):
     return ...

关于这种行为的进一步文档可以找到在医生里

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55156877

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档