首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >火炬nn.embedding误差

火炬nn.embedding误差
EN

Stack Overflow用户
提问于 2018-07-21 12:27:18
回答 1查看 2.7K关注 0票数 1

我正在阅读字嵌入上的py手电文档。

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(5)

word_to_ix = {"hello": 0, "world": 1, "how":2, "are":3, "you":4}
embeds = nn.Embedding(2, 5)  # 2 words in vocab, 5 dimensional embeddings
lookup_tensor = torch.tensor(word_to_ix["hello"], dtype=torch.long)
hello_embed = embeds(lookup_tensor)
print(hello_embed)

输出:

代码语言:javascript
复制
tensor([-0.4868, -0.6038, -0.5581,  0.6675, -0.1974])

这看起来不错,但是如果我将lookup_tensor行替换为

代码语言:javascript
复制
lookup_tensor = torch.tensor(word_to_ix["how"], dtype=torch.long)

我得到的错误是:

RuntimeError: index out of range at /Users/soumith/minicondabuild3/conda-bld/pytorch_1524590658547/work/aten/src/TH/generic/THTensorMath.c:343

我不明白它为什么会给RunTime在hello_embed = embeds(lookup_tensor)上的错误。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-07-22 15:38:57

当您声明embeds = nn.Embedding(2, 5)时,词汇表大小为2,嵌入大小为5。也就是说,每个单词都将由大小为5的向量表示,而词汇表中只有2个单词。

lookup_tensor = torch.tensor(word_to_ix["how"], dtype=torch.long)嵌入将尝试查找与拼写中的第三个单词对应的向量,但是嵌入的词汇量为2,这就是为什么会出现错误。

如果您声明embeds = nn.Embedding(5, 5),它应该可以正常工作。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51456059

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档