我正在尝试适应PyTorch nn模块中的嵌入类。
我注意到还有很多人和我有同样的问题,因此在PyTorch讨论论坛和堆栈溢出上发布了一些问题,但我仍然有些困惑。
根据正式文件,传递的参数是num_embeddings和embedding_dim,它们分别表示我们的字典(或词汇表)有多大,以及我们希望嵌入的维度分别是多少。
我不明白的是我该怎么解释这些。例如,我运行的小练习代码:
import torch
import torch.nn as nn
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)
a = torch.LongTensor([[1, 2, 3, 4], [4, 3, 2, 1]]) # (2, 4)
b = torch.LongTensor([[1, 2, 3], [2, 3, 1], [4, 5, 6], [3, 3, 3], [2, 1, 2],
[6, 7, 8], [2, 5, 2], [3, 5, 8], [2, 3, 6], [8, 9, 6],
[2, 6, 3], [6, 5, 4], [2, 6, 5]]) # (13, 3)
c = torch.LongTensor([[1, 2, 3, 2, 1, 2, 3, 3, 3, 3, 3],
[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]]) # (2, 11)当我通过a、b和c变量运行embedding时,会得到形状为(2, 4, 3)、(13, 3, 3)、(2, 11, 3)的嵌入结果。
让我困惑的是,我认为我们拥有的样本数量超过了预定义的嵌入数,我们应该得到一个错误?既然我定义的embedding有10嵌入,那么b是否应该给我一个错误,因为它是包含13个维度3字的张量?
发布于 2019-11-09 10:55:04
在您的例子中,下面是如何解释输入张量:
a = torch.LongTensor([[1, 2, 3, 4], [4, 3, 2, 1]]) # 2 sequences of 4 elements此外,您的嵌入层是这样解释的:
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3) # 10 distinct elements and each those is going to be embedded in a 3 dimensional space因此,只要输入张量在[0, 9]范围内,输入张量是否超过10个元素并不重要。例如,如果我们创建两个元素的张量,如:
d = torch.LongTensor([[1, 10]]) # 1 sequence of 2 elements当我们通过嵌入层传递这个张量时,我们会得到以下错误:
RuntimeError:超出范围的索引:尝试访问包含9行的表之外的索引10
概括地说,num_embeddings是词汇表中唯一元素的总数,而embedding_dim是每个嵌入向量通过嵌入层后的大小。因此,只要张量中的每个元素都在10+范围内,就可以有一个[0, 9]元素的张量,因为您定义了10个元素的词汇表大小。
https://stackoverflow.com/questions/58777282
复制相似问题