首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >我应该如何理解nn.Embeddings参数num_embeddings和embedding_dim?

我应该如何理解nn.Embeddings参数num_embeddings和embedding_dim?
EN

Stack Overflow用户
提问于 2019-11-09 07:27:40
回答 1查看 1.8K关注 0票数 1

我正在尝试适应PyTorch nn模块中的嵌入类。

我注意到还有很多人和我有同样的问题,因此在PyTorch讨论论坛和堆栈溢出上发布了一些问题,但我仍然有些困惑。

根据正式文件,传递的参数是num_embeddingsembedding_dim,它们分别表示我们的字典(或词汇表)有多大,以及我们希望嵌入的维度分别是多少。

我不明白的是我该怎么解释这些。例如,我运行的小练习代码:

代码语言:javascript
复制
import torch
import torch.nn as nn


embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)

a = torch.LongTensor([[1, 2, 3, 4], [4, 3, 2, 1]]) # (2, 4)

b = torch.LongTensor([[1, 2, 3], [2, 3, 1], [4, 5, 6], [3, 3, 3], [2, 1, 2],
                      [6, 7, 8], [2, 5, 2], [3, 5, 8], [2, 3, 6], [8, 9, 6],
                      [2, 6, 3], [6, 5, 4], [2, 6, 5]]) # (13, 3)

c = torch.LongTensor([[1, 2, 3, 2, 1, 2, 3, 3, 3, 3, 3],
                      [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]]) # (2, 11)

当我通过abc变量运行embedding时,会得到形状为(2, 4, 3)(13, 3, 3)(2, 11, 3)的嵌入结果。

让我困惑的是,我认为我们拥有的样本数量超过了预定义的嵌入数,我们应该得到一个错误?既然我定义的embedding10嵌入,那么b是否应该给我一个错误,因为它是包含13个维度3字的张量?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-11-09 10:55:04

在您的例子中,下面是如何解释输入张量:

代码语言:javascript
复制
a = torch.LongTensor([[1, 2, 3, 4], [4, 3, 2, 1]]) # 2 sequences of 4 elements

此外,您的嵌入层是这样解释的:

代码语言:javascript
复制
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3) # 10 distinct elements and each those is going to be embedded in a 3 dimensional space

因此,只要输入张量在[0, 9]范围内,输入张量是否超过10个元素并不重要。例如,如果我们创建两个元素的张量,如:

代码语言:javascript
复制
d = torch.LongTensor([[1, 10]]) # 1 sequence of 2 elements

当我们通过嵌入层传递这个张量时,我们会得到以下错误:

RuntimeError:超出范围的索引:尝试访问包含9行的表之外的索引10

概括地说,num_embeddings是词汇表中唯一元素的总数,而embedding_dim是每个嵌入向量通过嵌入层后的大小。因此,只要张量中的每个元素都在10+范围内,就可以有一个[0, 9]元素的张量,因为您定义了10个元素的词汇表大小。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58777282

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档