您好,我已经使用了非常棒的库huggingface转换器在GPT2中生成文本,效果非常好:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
input_ids = torch.tensor(tokenizer.encode("Once upon a time there was")).unsqueeze(0)
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
greedy_output = model.generate(input_ids, max_length=50)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))我的问题是,现在我想用更小更简单的DistilmBERT模型做同样的事情,它也是104种语言的多语言,所以我想用这个轻便的模型生成例如西班牙语和英语的文本
我已经试过了
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-multilingual-cased')
model = DistilBertForMaskedLM.from_pretrained('distilbert-base-multilingual-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2]但我不确定这是不是正确的模型。一旦我得到了输出,我怎么才能得到这个短语的延续呢?
经过更多的测试,我可以让这一代很好地与distilgpt2一起工作,问题是我想使用轻量级的多语言模型DistilmBERT (distilbert-base- multilingual cased)来做多语言工作,有什么建议吗?
import torch
from transformers import *
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
input_ids = torch.tensor(tokenizer.encode("Once upon a time")).unsqueeze(0)
model = GPT2LMHeadModel.from_pretrained("distilgpt2", pad_token_id=tokenizer.eos_token_id)
greedy_output = model.generate(input_ids, max_length=50) #greedy search
sample_outputs = model.generate(
input_ids,
do_sample=True,
max_length=50,
top_k=50,
top_p=0.95,
temperature=1,
num_return_sequences=3
)
print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):
print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))`感谢您的帮助:)
发布于 2020-06-03 18:18:18
我只是复制了LysandreJik给here的答案
不幸的是,DistilmBERT不能用于生成。这是由于原始BERT模型使用掩蔽语言建模(MLM)进行预训练的方式。因此,它同时关注左侧和右侧上下文(您试图生成的令牌的左侧和右侧的令牌),而对于生成,模型只能访问左侧上下文。
GPT-2使用因果语言建模(CLM)进行训练,这就是为什么它可以生成这样的相干序列。我们只为CLM模型实现生成方法,因为MLM模型不会生成任何连贯的东西。
在文档中,您可以找到适合任务的模型。
https://stackoverflow.com/questions/61990266
复制相似问题