首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch矩阵分解嵌入误差

PyTorch矩阵分解嵌入误差
EN

Stack Overflow用户
提问于 2019-10-01 23:33:13
回答 1查看 372关注 0票数 0

我正在尝试使用单个隐藏层NN来执行矩阵分解。一般来说,我试图解决一个张量,V,维度为9724x300,其中库存中有9724个项目,300是潜在特征的任意数量。

我拥有的数据是一个9724x9724矩阵,X,其中的列和行表示相互点赞的数量。(例如,X0,1表示同时喜欢项目0和项目1的用户总数。对角线条目并不重要。

我的目标是使用MSE损失,这样Vi,:on Vj,:转置的点积就非常非常接近Xi,j。

下面是我从下面的链接改编的代码。

https://blog.fastforwardlabs.com/2018/04/10/pytorch-for-recommenders-101.html

代码语言:javascript
复制
import torch
from torch.autograd import Variable

class MatrixFactorization(torch.nn.Module):
    def __init__(self, n_items=len(movie_ids), n_factors=300):
        super().__init__()

        self.vectors = nn.Embedding(n_items, n_factors,sparse=True)


    def forward(self, i,j):
        return (self.vectors([i])*torch.transpose(self.vectors([j]))).sum(1)

    def predict(self, i, j):
        return self.forward(i, j)

model = MatrixFactorization(n_items=len(movie_ids),n_factors=300)
loss_fn = nn.MSELoss() 
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for i in range(len(movie_ids)):
    for j in range(len(movie_ids)):
    # get user, item and rating data
        rating = Variable(torch.FloatTensor([Xij[i, j]]))
        # predict
#         i = Variable(torch.LongTensor([int(i)]))
#         j = Variable(torch.LongTensor([int(j)]))
        prediction = model(i, j)
        loss = loss_fn(prediction, rating)

        # backpropagate
        loss.backward()

        # update weights
        optimizer.step()

返回的错误为:

代码语言:javascript
复制
TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not list

我对嵌入非常陌生。我曾尝试将嵌入替换为简单的浮点张量,但是我定义的MatrixFactorization类没有将该张量识别为需要优化的模型参数。

有没有想过我哪里错了?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-10-02 01:26:26

您正在将一个列表传递给self.vectors

代码语言:javascript
复制
return (self.vectors([i])*torch.transpose(self.vectors([j]))).sum(1)

在调用self.vectors()之前尝试将其转换为张量

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58188133

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档