首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >BERT预训练MLM + NSP

BERT预训练MLM + NSP
EN

Stack Overflow用户
提问于 2021-11-26 10:21:05
回答 1查看 195关注 0票数 0

我想为传销+ NSP任务预先训练BERT。当我运行下面的代码时,抛出了一个错误:

RuntimeError:张量a (882)的大小必须与非单一维数为1 1%|▊| 3/561 00:02<06:13,1.49it/s的张量b (512)的大小匹配

这看起来像是一个截断问题。但是为什么呢?我只是使用了库。如果有人能开导我,我会很高兴。谢谢你的预支。

代码语言:javascript
复制
The code I run:

from transformers import BertTokenizer
from transformers import BertConfig, BertForPreTraining
from transformers import TextDatasetForNextSentencePrediction
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling

TOKENIZER_PATH = "hukuk_tokenizer"
MAX_LEN = 512
BLOCK_SIZE = 128
DATA_PATH = "data/toy_sentences_v3.removed_long_sent.txt"
OUTPUT_DIR = "/home/osahin/bert_yoktez/results/"
config = BertConfig()

if TOKENIZER_PATH == "hukuk_tokenizer":

        config.update({"vocab_size":30000})


print("config: ",config)

tokenizer = BertTokenizer.from_pretrained(TOKENIZER_PATH)
tokenizer.model_max_length= MAX_LEN
print("Tokenizer: ",tokenizer)

model = BertForPreTraining(config)

dataset= TextDatasetForNextSentencePrediction(
    tokenizer=tokenizer,
    file_path=DATA_PATH,
    block_size = BLOCK_SIZE
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability= 0.15
)

training_args = TrainingArguments(
    output_dir= OUTPUT_DIR,
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size= 32,
    save_steps=1000,
    save_on_each_node=True,
    prediction_loss_only=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

trainer.train()

注意:对于NSP任务,输入文件是以每行一句为单位准备的。

EN

回答 1

Stack Overflow用户

发布于 2021-11-26 13:00:58

错误The size of tensor a (882) must match the size of tensor b (512) at non-singleton dimension很可能意味着模型支持的最大文本大小是512个标记,但是您尝试向其传递一个包含882个标记的文本。要绕过这一点,您可以在管道中的某个地方启用截断(最有可能的是,在文本标记化的时刻,即在TextDatasetForNextSentencePrediction中或在其创建之后)。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70122842

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档