首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >`return_sequences = False`‘等效于pytorch LSTM

`return_sequences = False`‘等效于pytorch LSTM
EN

Stack Overflow用户
提问于 2020-06-04 21:04:29
回答 1查看 3.8K关注 0票数 9

在tensorflow/keras中,我们可以简单地在分类/完全连接/激活(softmax/sigmoid)层之前为最后一个LSTM层设置return_sequences = False,以消除时间维。

在PyTorch中,我没有发现类似的东西。对于分类任务,我不需要序列来对模型进行排序,而是需要像这样的多到一个体系结构:

这是我的简单的双LSTM模型。

代码语言:javascript
复制
import torch
from torch import nn

class BiLSTMClassifier(nn.Module):
    def __init__(self):
        super(BiLSTMClassifier, self).__init__()
        self.embedding = torch.nn.Embedding(num_embeddings = 65000, embedding_dim = 64)
        self.bilstm = torch.nn.LSTM(input_size = 64, hidden_size = 8, num_layers = 2,
                                    batch_first = True, dropout = 0.2, bidirectional = True)
        # as we have 5 classes
        self.linear = nn.Linear(8*2*512, 5) # last dimension
    def forward(self, x):
        x = self.embedding(x)
        print(x.shape)
        x, _ = self.bilstm(x)
        print(x.shape)
        x = self.linear(x.reshape(x.shape[0], -1))
        print(x.shape)

# create our model

bilstmclassifier = BiLSTMClassifier()

如果我观察每一层后的形状,

代码语言:javascript
复制
xx = torch.tensor(X_encoded[0]).reshape(1,512)
print(xx.shape) 
# torch.Size([1, 512])
bilstmclassifier(xx)
#torch.Size([1, 512, 64])
#torch.Size([1, 512, 16])
#torch.Size([1, 5])

如何使最后一个LSTM返回形状为(1, 16)而不是(1, 512, 16)的张量?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-05 08:21:11

实现这一目的的最简单方法是索引到张量中:

代码语言:javascript
复制
x = x[:, -1, :]

其中x是RNN输出。当然,如果batch_firstFalse,则必须使用x[-1, :, :] (或仅使用x[-1])来索引时间轴。事实证明,这也是Tensorflow/Keras所做的事情。相关代码可在K.rnn 这里中找到。

代码语言:javascript
复制
last_output = tuple(o[-1] for o in outputs)

请注意,此时的代码使用time_major数据格式,因此索引进入第一个轴。另外,outputs是一个元组,因为它可以是多层、状态/单元对等,但是它通常是所有时间步骤的输出序列。

然后在RNN类中使用它,如下所示:

代码语言:javascript
复制
if self.return_sequences:
    output = K.maybe_convert_to_ragged(is_ragged_input, outputs, row_lengths)
else:
    output = last_output

因此,总的来说,我们可以看到return_sequences=False只使用outputs[-1]

票数 14
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62204109

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档