
PyTorch loss does not decrease and validation accuracy stays the same

Stack Overflow user
Asked 2021-11-17 15:14:13
1 answer, 418 views, 0 followers, 1 vote

I am training a model with PyTorch on a text-classification task (the output dimension is 5). My network is implemented as in the code below.

class GRU(nn.Module):

    def __init__(self, model_param: ModelParam):
        super(GRU, self).__init__()

        self.embedding = nn.Embedding(model_param.vocab_size, model_param.embed_dim)

        # Build with pre-trained embedding vectors, if given.
        if model_param.vocab_embedding is not None:
            self.embedding.weight.data.copy_(model_param.vocab_embedding)
            self.embedding.weight.requires_grad = False

        self.rnn = nn.GRU(model_param.embed_dim,
                          model_param.hidden_dim,
                          num_layers=2,
                          bias=True,
                          batch_first=True,
                          dropout=0.5,
                          bidirectional=False)

        self.dropout = nn.Dropout(0.5)

        self.fc = nn.Sequential(
            nn.Linear(in_features=model_param.hidden_dim, out_features=128),
            nn.Linear(in_features=128, out_features=model_param.output_dim)
        )

    def forward(self, x, labels=None):
        '''
            :param x: torch.tensor, of shape [batch_size, max_seq_len].
            :param labels: torch.tensor, of shape [batch_size]. Not used in this model.
            :return outputs: torch.tensor, of shape [batch_size, output_dim].
        '''

        # [batch_size, max_seq_len, embed_dim].
        features = self.dropout(self.embedding(x))
        
        # [batch_size, max_seq_len, hidden_dim].
        outputs, _ = self.rnn(features)

        # [batch_size, hidden_dim].
        outputs = outputs[:, -1, :]

        return self.fc(self.dropout(outputs))
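
One thing that may be worth checking in `forward`: `outputs[:, -1, :]` takes the GRU output at position `max_seq_len - 1`, which for a right-padded batch is a padding step for every sequence shorter than the maximum. Below is a minimal sketch of selecting the last real step instead. It assumes right padding and a per-sequence `lengths` tensor; `lengths` and the helper name are hypothetical and do not appear in the original code.

```python
import torch


def last_valid_outputs(outputs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Pick the RNN output at the last non-padded step of each sequence.

    outputs: [batch_size, max_seq_len, hidden_dim] (batch_first=True layout)
    lengths: [batch_size], true length of each sequence, each >= 1 (assumed input)
    """
    batch_idx = torch.arange(outputs.size(0), device=outputs.device)
    # Advanced indexing: row i of the result is outputs[i, lengths[i] - 1, :].
    return outputs[batch_idx, lengths - 1]


# Toy check with a fake "GRU output" of shape [2, 4, 3].
outputs = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
lengths = torch.tensor([2, 4])  # first sequence has 2 real tokens, second has 4
picked = last_valid_outputs(outputs, lengths)
```

Inside the model this would replace the `outputs[:, -1, :]` line, provided the data pipeline can supply the lengths.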

I use nn.CrossEntropyLoss() as the loss function and optim.SGD as the optimizer. They are defined as follows.

# Loss function and optimizer.
loss_func = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=learning_rate, weight_decay=0.9)

My training procedure is roughly as follows.

            for batch in train_iter:

                optimizer.zero_grad()

                # The prediction of model, and its corresponding loss.
                prediction = model(batch.text.type(torch.LongTensor).to(device), batch.label.to(device))
                loss = loss_func(prediction, batch.label.to(device))

                loss.backward()
                optimizer.step()

                # Record total loss.
                epoch_losses.append(loss.item() / batch_size)
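
A side note on the recorded value: `nn.CrossEntropyLoss()` uses `reduction="mean"` by default, so `loss.item()` is already averaged over the batch, and dividing by `batch_size` again only rescales the logged number (it does not affect training). A small sketch with made-up logits and labels:

```python
import torch
import torch.nn as nn

# Made-up logits for a batch of 2 samples and 3 classes.
logits = torch.tensor([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
labels = torch.tensor([0, 2])

mean_loss = nn.CrossEntropyLoss()(logits, labels)                # default: mean over the batch
sum_loss = nn.CrossEntropyLoss(reduction="sum")(logits, labels)  # sum over the batch

# The default result equals the summed loss divided by the batch size.
already_averaged = torch.isclose(mean_loss, sum_loss / logits.size(0))
```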

When I train this model, the validation accuracy and loss are reported like this.

Epoch 1/300 valid acc: [0.839] (16668 in 19873), time spent 631.497 sec. Validate loss 1.506138. Best validate epoch is 1.
Epoch 2/300 valid acc: [0.839] (16668 in 19873), time spent 627.631 sec. Validate loss 1.577007. Best validate epoch is 2.
Epoch 3/300 valid acc: [0.839] (16668 in 19873), time spent 631.427 sec. Validate loss 1.580756. Best validate epoch is 3.
Epoch 4/300 valid acc: [0.839] (16668 in 19873), time spent 605.352 sec. Validate loss 1.581306. Best validate epoch is 4.
Epoch 5/300 valid acc: [0.839] (16668 in 19873), time spent 388.487 sec. Validate loss 1.581431. Best validate epoch is 5.
Epoch 6/300 valid acc: [0.839] (16668 in 19873), time spent 360.344 sec. Validate loss 1.581464. Best validate epoch is 6.
Epoch 7/300 valid acc: [0.839] (16668 in 19873), time spent 624.345 sec. Validate loss 1.581473. Best validate epoch is 7.
Epoch 8/300 valid acc: [0.839] (16668 in 19873), time spent 622.059 sec. Validate loss 1.581477. Best validate epoch is 8.
Epoch 9/300 valid acc: [0.839] (16668 in 19873), time spent 651.425 sec. Validate loss 1.581478. Best validate epoch is 9.
Epoch 10/300 valid acc: [0.839] (16668 in 19873), time spent 697.475 sec. Validate loss 1.581478. Best validate epoch is 10.
...

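The results show that the validation loss no longer decreases (it is flat at about 1.5815 by epoch 9), and the validation accuracy has stayed the same since the first epoch. Note that one label accounts for 83% of my dataset, so it can be inferred that my model tends to predict the same label for every sequence; however, the same thing happens when I train on another, less imbalanced dataset. Has anyone run into this before? I would like to know whether I made a mistake in designing the model or in the training procedure. Thanks for your help XD.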
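
The suspicion that the model predicts a single label can be checked directly by comparing the accuracy with the majority-class baseline and by counting how often each class is predicted. A minimal sketch in plain Python; the class ids and the split of the non-majority samples are made up, only the 16668 of 19873 figure comes from the log above.

```python
from collections import Counter


def majority_baseline(labels):
    """Return the most frequent label and the accuracy of always predicting it."""
    label, count = Counter(labels).most_common(1)[0]
    return label, count / len(labels)


def prediction_histogram(predictions):
    """Count how often each class is predicted; a single key suggests collapse."""
    return dict(Counter(predictions))


# Toy validation labels mirroring the reported split (16668 of 19873 in one class).
labels = [0] * 16668 + [1] * 1000 + [2] * 1000 + [3] * 705 + [4] * 500
majority_label, baseline_acc = majority_baseline(labels)

# A collapsed model would produce predictions like these.
predictions = [majority_label] * len(labels)
histogram = prediction_histogram(predictions)
```

If the reported accuracy equals `baseline_acc` and the histogram has a single key, the model is very likely outputting the majority class for every input.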

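Update (Nov 19): I added a figure showing how the loss behaves during training. It shows that both the training loss and the validation loss stay constant after epoch 5.

[Figure: training and validation loss over 20 epochs]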


1 Answer

Stack Overflow user

Answered 2021-11-18 09:46:51

I have now found that the loss was not decreasing mainly because the weight decay set in the optimizer was too high.

optimizer = SGD(model.parameters(), lr=learning_rate, weight_decay=0.9)

So I fixed this by changing the weight decay to 5e-5.

optimizer = SGD(model.parameters(), lr=learning_rate, weight_decay=5e-5)
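
To illustrate why a weight decay of 0.9 is so damaging: in `SGD`, `weight_decay` adds `weight_decay * w` to each parameter's gradient, which pulls the weights toward zero. The sketch below is a one-parameter toy problem, not the GRU from the question. It minimizes `0.5 * (w - target) ** 2` with that extra term; the function name, `target`, and the toy learning rate are illustrative assumptions.

```python
import math


def sgd_fixed_point(target, lr, weight_decay, steps=500):
    """Run plain SGD on 0.5 * (w - target)**2 with L2-style weight decay."""
    w = 0.0
    for _ in range(steps):
        grad = w - target          # gradient of the data term
        grad += weight_decay * w   # how SGD's weight_decay enters the update
        w -= lr * grad
    return w


# The iteration converges to target / (1 + weight_decay).
w_heavy = sgd_fixed_point(3.0, lr=0.1, weight_decay=0.9)   # pulled far below 3.0
w_light = sgd_fixed_point(3.0, lr=0.1, weight_decay=5e-5)  # stays close to 3.0

# Cross-entropy of a uniform prediction over 5 classes, about 1.609.
uniform_loss = math.log(5)
```

The plateau of about 1.58 in the first run is close to `uniform_loss`, which would be consistent with weights shrunk almost to zero and nearly uniform outputs, although the logs alone cannot confirm this.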

This time the loss of my network started to decrease. However, the accuracy did not improve at all.

Epoch 1/100 valid acc: [0.839] (16668 in 19873), time spent 398.154 sec. Validate loss 0.713456. Best validate epoch is 1.
Epoch 2/100 valid acc: [0.839] (16668 in 19873), time spent 572.057 sec. Validate loss 0.631721. Best validate epoch is 2.
Epoch 3/100 valid acc: [0.839] (16668 in 19873), time spent 580.867 sec. Validate loss 0.613186. Best validate epoch is 3.
Epoch 4/100 valid acc: [0.839] (16668 in 19873), time spent 561.953 sec. Validate loss 0.601883. Best validate epoch is 4.
Epoch 5/100 valid acc: [0.839] (16668 in 19873), time spent 564.913 sec. Validate loss 0.596573. Best validate epoch is 5.
Epoch 6/100 valid acc: [0.839] (16668 in 19873), time spent 574.525 sec. Validate loss 0.592848. Best validate epoch is 6.
Epoch 7/100 valid acc: [0.839] (16668 in 19873), time spent 580.885 sec. Validate loss 0.591074. Best validate epoch is 7.
Epoch 8/100 valid acc: [0.839] (16668 in 19873), time spent 455.228 sec. Validate loss 0.589787. Best validate epoch is 8.
Epoch 9/100 valid acc: [0.839] (16668 in 19873), time spent 582.756 sec. Validate loss 0.588691. Best validate epoch is 9.
Epoch 10/100 valid acc: [0.839] (16668 in 19873), time spent 583.997 sec. Validate loss 0.588260. Best validate epoch is 10.
Epoch 11/100 valid acc: [0.839] (16668 in 19873), time spent 599.630 sec. Validate loss 0.588224. Best validate epoch is 11.
Epoch 12/100 valid acc: [0.839] (16668 in 19873), time spent 597.713 sec. Validate loss 0.586977. Best validate epoch is 12.
Epoch 13/100 valid acc: [0.839] (16668 in 19873), time spent 605.038 sec. Validate loss 0.587937. Best validate epoch is 13.
Epoch 14/100 valid acc: [0.839] (16668 in 19873), time spent 598.712 sec. Validate loss 0.587059. Best validate epoch is 14.
Epoch 15/100 valid acc: [0.839] (16668 in 19873), time spent 409.344 sec. Validate loss 0.587293. Best validate epoch is 15.
...
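
A falling loss with a frozen accuracy is possible when the model keeps predicting the majority class and only its confidence changes. The sketch below compares two constant predictors against an imbalanced label distribution. The distribution is hypothetical (only the 0.839 majority share comes from the log), and both function names are made up for the example.

```python
import math


def cross_entropy(true_dist, predicted_probs):
    """Expected cross-entropy of a constant prediction under the label distribution."""
    return -sum(q * math.log(p) for q, p in zip(true_dist, predicted_probs))


def accuracy_of_constant(true_dist, predicted_probs):
    """Accuracy of always predicting the class with the highest probability."""
    top = max(range(len(predicted_probs)), key=predicted_probs.__getitem__)
    return true_dist[top]


# Hypothetical label shares: 0.839 majority, the rest split evenly.
true_dist = [0.839, 0.04025, 0.04025, 0.04025, 0.04025]

overconfident = [0.99, 0.0025, 0.0025, 0.0025, 0.0025]
calibrated = list(true_dist)  # predicts exactly the class shares

loss_over = cross_entropy(true_dist, overconfident)
loss_cal = cross_entropy(true_dist, calibrated)
acc_over = accuracy_of_constant(true_dist, overconfident)
acc_cal = accuracy_of_constant(true_dist, calibrated)
```

Here `loss_cal` is lower than `loss_over` while both accuracies equal 0.839, so the loss can keep improving without a single prediction changing.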

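The training loss behaves as shown in the figure below.

[Figure: training loss]

I would like to know whether a learning rate of 1e-3 and a weight decay of 5e-5 are reasonable settings. The batch size I specified is 32.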

Votes: 1
Original page content provided by Stack Overflow; translation support provided by the Tencent Cloud Xiaowei IT-domain engine.
Original link:

https://stackoverflow.com/questions/70006954
