首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >BERT模型如何选择标签排序?

BERT模型如何选择标签排序?
EN

Stack Overflow用户
提问于 2021-04-21 14:15:39
回答 1查看 185关注 0票数 1

我正在训练BertForSequenceClassification完成一项分类任务。我的数据集由“包含不利影响”(1)和“不包含不利影响”(0)组成。数据集包含所有的1,然后是0(数据不会被打乱)。为了训练,我对我的数据进行了混洗,并获得了日志。据我所知,logits是softmax之前的概率分布。例如logit是-4.673831,4.7095485。第一个值是否对应于标签1(包含AE),因为它首先出现在数据集中或标签0中。任何帮助都将不胜感激。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-04-21 20:19:13

第一个值对应于标签0,第二个值对应于标签1。BertForSequenceClassification所做的是将池子的输出馈送到线性层(在丢失之后,我将在此答案中忽略它)。让我们看一下下面的例子:

代码语言:javascript
复制
from torch import nn
from transformers import BertModel, BertTokenizer
t = BertTokenizer.from_pretrained('bert-base-uncased')
m = BertModel.from_pretrained('bert-base-uncased')
i = t.encode_plus('This is an example.', return_tensors='pt')
o = m(**i)
print(o.pooler_output.shape)

输出:

代码语言:javascript
复制
torch.Size([1, 768])

pooled_output是形状batch_size,hidden_size的张量,表示输入序列的上下文化(即应用了注意力) [CLS]标记。此张量被馈送到线性层,以计算序列的logits

代码语言:javascript
复制
classificationLayer = nn.Linear(768,2)
logits = classificationLayer(o.pooler_output)

当我们标准化这些logit时,我们可以看到线性层预测我们的输入应该属于标签1:

代码语言:javascript
复制
print(nn.functional.softmax(logits,dim=-1))

输出(由于线性层是随机初始化的,因此会有所不同):

代码语言:javascript
复制
tensor([[0.1679, 0.8321]], grad_fn=<SoftmaxBackward>)

线性图层应用了线性变换:y=xA^T+b,您已经可以看到线性图层并不知道您的标签。它‘唯一’有一个大小为2,768的权重矩阵来产生大小为1,2的logit

代码语言:javascript
复制
import torch:

logitsOwnCalculation = torch.matmul(o.pooler_output,  classificationLayer.weight.transpose(0,1))+classificationLayer.bias
print(nn.functional.softmax(logitsOwnCalculation,dim=-1))

输出:

代码语言:javascript
复制
tensor([[0.1679, 0.8321]], grad_fn=<SoftmaxBackward>)

BertForSequenceClassification模型通过应用CrossEntropyLoss来学习。当某个类别(在您的案例中为label)的逻辑仅稍微偏离预期时,此损失函数会产生小的损失。这意味着CrossEntropyLoss让你的模型知道,当输入does not contain adverse effect时,第一个logit应该是高的,当它是contains adverse effect时,它应该是小的。对于我们的示例,您可以使用以下内容进行检查:

代码语言:javascript
复制
loss_fct = nn.CrossEntropyLoss()
label0 = torch.tensor([0]) #does not contain adverse effect
label1 = torch.tensor([1]) #contains adverse effect
print(loss_fct(logits, label0))
print(loss_fct(logits, label1))

输出:

代码语言:javascript
复制
tensor(1.7845, grad_fn=<NllLossBackward>)
tensor(0.1838, grad_fn=<NllLossBackward>)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67190212

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档