首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >从BertForSequenceClassification中提取特征

从BertForSequenceClassification中提取特征
EN

Stack Overflow用户
提问于 2021-03-27 01:16:42
回答 1查看 246关注 0票数 1

大家好,目前我正在尝试开发一个用于冲突检测的模型。使用和微调BERT模型,我已经得到了相当统计的结果,但我认为使用其他一些功能,我可以获得更好的准确性。我把自己定位在这个Tutorial上。经过微调后,我的模型如下所示:

代码语言:javascript
复制
==== Embedding Layer ====

bert.embeddings.word_embeddings.weight                  (30000, 768)
bert.embeddings.position_embeddings.weight                (512, 768)
bert.embeddings.token_type_embeddings.weight                (2, 768)
bert.embeddings.LayerNorm.weight                              (768,)
bert.embeddings.LayerNorm.bias                                (768,)

==== First Transformer ====

bert.encoder.layer.0.attention.self.query.weight          (768, 768)
bert.encoder.layer.0.attention.self.query.bias                (768,)
bert.encoder.layer.0.attention.self.key.weight            (768, 768)
bert.encoder.layer.0.attention.self.key.bias                  (768,)
bert.encoder.layer.0.attention.self.value.weight          (768, 768)
bert.encoder.layer.0.attention.self.value.bias                (768,)
bert.encoder.layer.0.attention.output.dense.weight        (768, 768)
bert.encoder.layer.0.attention.output.dense.bias              (768,)
bert.encoder.layer.0.attention.output.LayerNorm.weight        (768,)
bert.encoder.layer.0.attention.output.LayerNorm.bias          (768,)
bert.encoder.layer.0.intermediate.dense.weight           (3072, 768)
bert.encoder.layer.0.intermediate.dense.bias                 (3072,)
bert.encoder.layer.0.output.dense.weight                 (768, 3072)
bert.encoder.layer.0.output.dense.bias                        (768,)
bert.encoder.layer.0.output.LayerNorm.weight                  (768,)
bert.encoder.layer.0.output.LayerNorm.bias                    (768,)

==== Output Layer ====

bert.pooler.dense.weight                                  (768, 768)
bert.pooler.dense.bias                                        (768,)
classifier.weight                                           (2, 768)
classifier.bias                                                 (2,)

我的下一步将是从这个模型中获取CLS令牌,将其与一些手工制作的功能组合在一起,并将它们提供给不同的模型(MLP)进行分类。有什么建议要怎么做吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-03-27 23:41:01

您可以使用bert模型的池化输出(提供给池化layes的CLS令牌的上下文嵌入):

代码语言:javascript
复制
from transformers import BertModel, BertTokenizer

#replace bert-base-uncased with the path to your saved model
t = BertTokenizer.from_pretrained('bert-base-uncased')
m = BertModel.from_pretrained('bert-base-uncased')


i = t.batch_encode_plus(['this is a sample', 'different sample'], padding=True,return_tensors='pt')
o = m(**i)

print(o.keys())
#shape [batch_size, 768]
print(o.pooler_output.shape)
useMe = o.pooler_output
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66821505

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档