首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >pytorch LSTM将一系列特征向量映射到其标注

pytorch LSTM将一系列特征向量映射到其标注
EN

Stack Overflow用户
提问于 2021-05-24 22:06:08
回答 1查看 44关注 0票数 0

目前我有形状为( 50,25)的输入X,其中有50个特征向量,每个向量有25个维度。例如,X的数据如下:

代码语言:javascript
复制
X = [[0. 0. 0. ... 1. 1. 1.]
 [0. 0. 0. ... 1. 1. 1.]
 [0. 0. 0. ... 1. 1. 1.]
 ...
 [0. 0. 0. ... 1. 1. 1.]
 [0. 0. 0. ... 1. 1. 1.]
 [0. 0. 0. ... 1. 1. 1.]]

并且输出标签y是长度为50的[0 0 0 0 0 0 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]。即每个特征向量具有对应于y中的元素的标签。

如何构造pytorch LSTM,将输入对象重塑为3维,并正确解释输出对象?非常感谢你事先的帮助。

目前我有一个这样的LSTM模板,因为我的输入已经是数字的,我想去掉编码器/解码器部分,对吗?

代码语言:javascript
复制
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0, tie_weights=False):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.ntoken = ntoken 
        self.decoder = nn.Linear(nhid, self.ntoken)
        if rnn_type in ['LSTM', 'GRU']:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
            except KeyError:
                raise ValueError( """An invalid option for `--model` was supplied,
                                 options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        nn.init.zeros_(self.decoder.weight)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(input)
        emb = emb.transpose(1, 0)

        output, hidden = self.rnn(emb, hidden) #output of shape (length, batchsize, nhid)
        output = self.drop(output)
        output = output[-1, :, :] #shape (batchsize, nhid)

        decoded = self.decoder(output) #shape (batchsize, ntoken)
        return F.log_softmax(decoded, dim=1), hidden 

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        if self.rnn_type == 'LSTM':
            return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                    weight.new_zeros(self.nlayers, bsz, self.nhid))
        else:
            return weight.new_zeros(self.nlayers, bsz, self.nhid)

目前我写的列车是

代码语言:javascript
复制
X = X.reshape((1, 50, 25))
hidden = self.model.init_hidden(1)
for iter in range(0, self.epochs):
    data = torch.from_numpy(X)
    target = torch.LongTensor(y.reshape((1, torch.LongTensor(y).size(0))))
    self.model.zero_grad()
    self.optimizer.zero_grad()
    hidden = self.repackage_hidden(hidden)
  
    output, hidden = self.model(data.float(), hidden)   
    loss = self.criterion(output, target)
    loss.backward() 
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.25) 
    self.optimizer.step()
    self.model.train() 

但是我得到了错误:RuntimeError: multi-target not supported at /tmp/pip-req-build-4baxydiv/aten/src/THNN/generic/ClassNLLCriterion.c:22

EN

回答 1

Stack Overflow用户

发布于 2021-05-24 23:16:53

rnn的输出形式为(长度,批大小,nhid),基于您的标签(每个样本1个数字)我假设您正在进行分类,所以通常我们会给分类器(self.decoder)最后一个时间步的输出特征。在这里,我将转发方法改为this,并得到了shape (batchsize,ntoken)的输出,它适合您的标签的形状。

代码语言:javascript
复制
def forward(self, input, hidden):
    emb = self.drop(self.encoder(input))
    emb = emb.transpose(1, 0) #(batchsize, length, ninp) => (length, batchsize, ninp)

    output, hidden = self.rnn(emb, hidden) #output of shape (length, batchsize, nhid)
    output = self.drop(output)
    output = output[-1, :, :] #shape (batchsize, nhid)

    decoded = self.decoder(output) #shape (batchsize, ntoken)
    return F.log_softmax(decoded, dim=1), hidden 

关于摆脱self.encoder,它是一个嵌入层,它采用一个索引数组并将每个索引替换为一个向量。如果你的输入包括某物的索引(int/long),你可以使用它,否则(它不是索引,而是像温度这样的浮点数,...)你应该摆脱它(因为它是错误的)。如果我的英语令人困惑,我很抱歉。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67673473

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档