首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用T5实现句子嵌入

使用T5实现句子嵌入
EN

Stack Overflow用户
提问于 2020-10-29 02:35:23
回答 1查看 1.1K关注 0票数 2

我想使用最先进的LM T5来获得句子嵌入向量。我发现了这个仓库https://github.com/UKPLab/sentence-transformers,据我所知,在BERT中,我应该将第一个令牌作为CLS令牌,它将是句子嵌入。在这个存储库中,我在T5模型上看到了相同的行为:

代码语言:javascript
复制
cls_tokens = output_tokens[:, 0, :]  # CLS token is first token

这种行为正确吗?我从T5获取了编码器,并用它对两个短语进行了编码:

代码语言:javascript
复制
"I live in the kindergarden"
"Yes, I live in the kindergarden"

它们之间的余弦相似度仅为"0.2420“。

我只需要理解句子嵌入是如何工作的--我应该训练网络来找到相似度以获得正确的结果吗?或者我,这是足够的基础预训练语言模型?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-10-29 07:20:45

为了从T5获得句子嵌入,您需要从T5编码器输出中获取last_hidden_state

代码语言:javascript
复制
model.encoder(input_ids=s, attention_mask=attn, return_dict=True)
pooled_sentence = output.last_hidden_state # shape is [batch_size, seq_len, hidden_size]
# pooled_sentence will represent the embeddings for each word in the sentence
# you need to sum/average the pooled_sentence
pooled_sentence = torch.mean(pooled_sentence, dim=1)

您现在有了来自T5的句子嵌入

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64579258

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档