首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何还原BERT/XLNet嵌入?

如何还原BERT/XLNet嵌入?
EN

Stack Overflow用户
提问于 2020-04-03 01:23:44
回答 1查看 260关注 0票数 2

我最近一直在尝试堆叠语言模型,并注意到一些有趣的事情: BERT和XLNet的输出嵌入与输入嵌入不同。例如,下面的代码片段:

代码语言:javascript
复制
bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")
tok = transformers.BertTokenizer.from_pretrained("bert-base-cased")

sent = torch.tensor(tok.encode("I went to the store the other day, it was very rewarding."))
enc = bert.get_input_embeddings()(sent)
dec = bert.get_output_embeddings()(enc)

print(tok.decode(dec.softmax(-1).argmax(-1)))

为我输出以下内容:

代码语言:javascript
复制
,,,,,,,,,,,,,,,,,

我本来希望返回(格式化的)输入序列,因为我的印象是输入和输出标记嵌入是绑定的。

有趣的是,大多数其他模型都没有表现出这种行为。例如,如果您在GPT2、Albert或Roberta上运行相同的代码片段,它将输出输入序列。

这是一个bug吗?或者这是BERT/XLNet所期望的?

EN

回答 1

Stack Overflow用户

发布于 2020-12-13 00:51:46

我不确定是否太晚了,但我已经用你的代码做了一点实验,它可以被还原。:)

代码语言:javascript
复制
bert = transformers.BertForMaskedLM.from_pretrained("bert-base-cased")
tok = transformers.BertTokenizer.from_pretrained("bert-base-cased")

sent = torch.tensor(tok.encode("I went to the store the other day, it was very rewarding."))
print("Initial sentence:", sent)
enc = bert.get_input_embeddings()(sent)
dec = bert.get_output_embeddings()(enc)

print("Decoded sentence:", tok.decode(dec.softmax(0).argmax(1)))

为此,您将获得以下输出:

代码语言:javascript
复制
Initial sentence: tensor([  101,   146,  1355,  1106,  1103,  2984,  1103,  1168,  1285,   117,
         1122,  1108,  1304, 10703,  1158,   119,   102])  
Decoded sentence: [CLS] I went to the store the other day, it was very rewarding. [SEP]
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60997438

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档