首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用DistilBERT生成文本的句子

使用DistilBERT生成文本的句子
EN

Stack Overflow用户
提问于 2020-05-25 02:10:52
回答 1查看 490关注 0票数 1

您好,我已经使用了非常棒的库huggingface转换器在GPT2中生成文本,效果非常好:

代码语言:javascript
复制
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
input_ids = torch.tensor(tokenizer.encode("Once upon a time there was")).unsqueeze(0)
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
greedy_output = model.generate(input_ids, max_length=50)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))

我的问题是,现在我想用更小更简单的DistilmBERT模型做同样的事情,它也是104种语言的多语言,所以我想用这个轻便的模型生成例如西班牙语和英语的文本

我已经试过了

代码语言:javascript
复制
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-multilingual-cased')
model = DistilBertForMaskedLM.from_pretrained('distilbert-base-multilingual-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2]

但我不确定这是不是正确的模型。一旦我得到了输出,我怎么才能得到这个短语的延续呢?

经过更多的测试,我可以让这一代很好地与distilgpt2一起工作,问题是我想使用轻量级的多语言模型DistilmBERT (distilbert-base- multilingual cased)来做多语言工作,有什么建议吗?

代码语言:javascript
复制
import torch
from transformers import *
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
input_ids = torch.tensor(tokenizer.encode("Once upon a time")).unsqueeze(0)
model = GPT2LMHeadModel.from_pretrained("distilgpt2", pad_token_id=tokenizer.eos_token_id)
greedy_output = model.generate(input_ids, max_length=50) #greedy search

sample_outputs = model.generate(
    input_ids,
    do_sample=True, 
    max_length=50, 
    top_k=50, 
    top_p=0.95, 
    temperature=1,
    num_return_sequences=3
)

print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):
  print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))`

感谢您的帮助:)

EN

回答 1

Stack Overflow用户

发布于 2020-06-03 18:18:18

我只是复制了LysandreJikhere的答案

不幸的是,DistilmBERT不能用于生成。这是由于原始BERT模型使用掩蔽语言建模(MLM)进行预训练的方式。因此,它同时关注左侧和右侧上下文(您试图生成的令牌的左侧和右侧的令牌),而对于生成,模型只能访问左侧上下文。

GPT-2使用因果语言建模(CLM)进行训练,这就是为什么它可以生成这样的相干序列。我们只为CLM模型实现生成方法,因为MLM模型不会生成任何连贯的东西。

在文档中,您可以找到适合任务的模型。

https://transformer.huggingface.co/中的一个快速示例

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61990266

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档