首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >通过Huggingface转换器更新BERT模型

通过Huggingface转换器更新BERT模型
EN

Stack Overflow用户
提问于 2019-10-30 15:19:24
回答 1查看 2.4K关注 0票数 9

我正在尝试使用内部语料库更新预训练的BERT模型。我看过Huggingface的transformer文档,你会发现我有点困惑,below.My的目标是使用余弦距离计算句子之间的简单相似度,但我需要为我的特定用例更新预先训练的模型。

如果你看一下下面的代码,这正是Huggingface文档中的代码。我试图“重新训练”或更新模型,我假设special_token_1和special_token_2表示来自我的“内部”数据或语料库的“新句子”。这是正确的吗?总而言之,我喜欢已经预训练的BERT模型,但我想使用另一个内部数据集来更新或重新训练它。任何线索都将不胜感激。

代码语言:javascript
复制
import tensorflow as tf
import tensorflow_datasets
from transformers import *

model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

SPECIAL_TOKEN_1="dogs are very cute"
SPECIAL_TOKEN_2="dogs are cute but i like cats better and my 
brother thinks they are more cute"

tokenizer.add_tokens([SPECIAL_TOKEN_1, SPECIAL_TOKEN_2])
model.resize_token_embeddings(len(tokenizer))
#Train our model
model.train()
model.eval()
EN

回答 1

Stack Overflow用户

发布于 2020-10-19 22:24:18

BERT在两个任务上进行了预训练:掩蔽语言建模(MLM)和下一句预测(NSP)。其中最重要的是传播学(事实证明,下一个句子预测任务对模型的语言理解能力没有多大帮助-例如,RoBERTa只对传播学进行了预训练)。

如果您想在自己的数据集上进一步训练模型,可以通过在Transformers存储库中使用BERTForMaskedLM来实现。这是BERT,顶部有一个语言建模头,它允许你在自己的数据集上执行掩蔽语言建模(即预测掩蔽标记)。下面是它的使用方法:

代码语言:javascript
复制
from transformers import BertTokenizer, BertForMaskedLM 
import torch   

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 
model = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True) 

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt") 
labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]

outputs = model(**inputs, labels=labels) 
loss = outputs.loss 
logits = outputs.logits

您可以使用loss.backward()更新BertForMaskedLM的权重,这是训练PyTorch模型的主要方法。如果您不想自己执行此操作,Transformers库还提供了一个Python脚本,它允许您在自己的数据集上真正快速地执行传销。参见here ( "RoBERTa/BERT/DistilBERT和屏蔽语言建模“一节)。你只需要提供一个训练和测试文件。

您不需要添加任何特殊的令牌。特殊令牌的示例是CLS和SEP,它们用于序列分类和问答任务(以及其他任务)。这些是由tokenizer自动添加的。我怎么知道的?因为BertTokenizer继承自PretrainedTokenizer,如果您查看一下它的__call__方法here的文档,您可以看到add_special_tokens参数缺省为True。

票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58620282

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档