首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何利用pytorch nn.Transformer进行序列分类?

如何利用pytorch nn.Transformer进行序列分类?
EN

Stack Overflow用户
提问于 2019-09-25 06:02:53
回答 2查看 3.8K关注 0票数 2

我正在使用nn.TransformerEncoder()执行序列分类任务。其管道类似于nn.LSTM()

我尝试过几种时态特征融合方法:

选择最终输出作为整个sequence.

  • Using的表示,采用仿射变换逐帧融合这些features.

  • Classifying,然后选择最大值作为整个序列的范畴。

但是,这3种方法的准确率都很差,只有25%的对4类进行分类。当使用nn.LSTM和最后一个隐藏状态时,我可以很容易地达到83%的精度。我尝试了大量的超参数的nn.TransformerEncoder(),但没有任何提高的准确性。

我现在不知道该怎么调整这个型号。你能给我一些实用的建议吗?谢谢。

对于LSTMforward()是:

代码语言:javascript
复制
    def forward(self, x_in, x_lengths, apply_softmax=False):

        # Embed
        x_in = self.embeddings(x_in)

        # Feed into RNN
        out, h_n = self.LSTM(x_in) #shape of out: T*N*D

        # Gather the last relevant hidden state
        out = out[-1,:,:] # N*D

        # FC layers
        z = self.dropout(out)
        z = self.fc1(z)
        z = self.dropout(z)
        y_pred = self.fc2(z)

        if apply_softmax:
            y_pred = F.softmax(y_pred, dim=1)
        return y_pred

对于transformer

代码语言:javascript
复制
    def forward(self, x_in, x_lengths, apply_softmax=False):

        # Embed
        x_in = self.embeddings(x_in)

        # Feed into RNN
        out = self.transformer(x_in)#shape of out T*N*D

        # Gather the last relevant hidden state
        out = out[-1,:,:] # N*D

        # FC layers
        z = self.dropout(out)
        z = self.fc1(z)
        z = self.dropout(z)
        y_pred = self.fc2(z)

        if apply_softmax:
            y_pred = F.softmax(y_pred, dim=1)
        return y_pred
EN

回答 2

Stack Overflow用户

发布于 2019-09-25 17:47:05

你提到的准确性表明出了问题。由于您正在比较LSTM和TransformerEncoder,我想指出一些关键的差异。

  1. Positional嵌入--:这是非常重要的,因为转换器没有递归概念,因此它不捕获序列信息。因此,请确保添加位置信息以及输入embeddings.
  2. Model体系结构d_modeln_headnum_encoder_layers非常重要。使用Vaswani等人,2017中使用的默认大小。( num_encoder_layers=6)
  3. Optimization:,n_head=8d_model=512)在很多情况下,人们发现变压器需要用更小的学习率,更大的批量,更小的WarmUpScheduling.

来训练。

最后但并非最不重要的一点是,为了进行健康检查,只需确保模型的参数正在更新。您还可以检查培训的准确性,以确保随着培训的进行,准确性不断提高。

虽然很难说出你的代码到底出了什么问题,但我希望以上几点会有所帮助!

票数 4
EN

Stack Overflow用户

发布于 2022-11-08 10:43:22

我不确定Selecting the final outputs as the representation of the whole sequence.对变压器是否正确。由于这些模型的工作方式与递归网络不同。最后一个时间点并不表示序列的完全嵌入。所以,使用上一次的观点,我认为你正在丢弃大量的信息。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58092004

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档