首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PureFrameworkTensorFoundError,运行时错误-FedeartedLearning

PureFrameworkTensorFoundError,运行时错误-FedeartedLearning
EN

Stack Overflow用户
提问于 2019-12-26 13:43:05
回答 1查看 298关注 0票数 1

我正在使用Pytorch尝试使用联邦学习的线性回归算法,我遇到了以下错误。我正在Colab上实现它。根据我的说法,这个错误可能是由于train()函数中的某些代码行造成的。好心的帮助是,你已经与Pysyft合作,并遇到过这样的错误。

代码语言:javascript
复制
RuntimeError: invalid argument 8: lda should be at least max(1, 0), but have 0 at /pytorch/aten/src/TH/generic/THBlas.cpp:363

代码如下:

代码语言:javascript
复制
#import the necessasry packages
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import syft as sy

#create target and data variables as tensors
x_data=Variable(torch.Tensor([[1.0],[0.0],[1.0],[0.0]]))
y_data=Variable(torch.Tensor([[0.0],[0.0],[1.0],[1.0]]))

#Create virtual Workers
hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")

data_bob = x_data[0:2]
target_bob = y_data[0:2]
data_alice = x_data[2:0]
target_alice = y_data[2:0]

#creating a class that does Linear Regression
class LinearRegression (nn.Module):

  def __init__(self):
    super(LinearRegression,self). __init__ ()
    self.linear = torch.nn.Linear(1,1)

  def forward(self, x):
    y_pred = self.linear(x)
    return y_pred

#assign the function to the variable name 'Model'
model=LinearRegression()

#send the data to the virtual worker pointers
data_bob = data_bob.send(bob)
data_alice = data_alice.send(alice)

target_bob = target_bob.send(bob)
target_alice = target_alice.send(alice)

# organize pointers into a list
datasets = [(data_bob,target_bob),(data_alice,target_alice)]

#create optimizer and calculate the loss
opt = torch.optim.SGD(params=model.parameters(),lr=0.1)
criterion = torch.nn.MSELoss(size_average=False)

def train():
  opt = torch.optim.SGD(params=model.parameters(),lr=0.1)
  for epoch in range (20):
    model.train()
    print("Training started..")

    for x_data,y_data in datasets:

      model.send(x_data.location) 

      opt.zero_grad()

       #forwardpass
       #the model here is the linear regression model
      y_pred = model(x_data)

      #ComputeLoss
      loss=criterion(y_pred,y_data)

      #BackwardPass
      loss.backward()

      opt.step()

      model.get() 

      print(loss.get())

train()
EN

回答 1

Stack Overflow用户

发布于 2019-12-27 18:35:18

你这里有一个打字错误:

代码语言:javascript
复制
data_alice = x_data[2:0]
target_alice = y_data[2:0]

应为[2:]

因为data_alice失败了,所以出现了这个错误。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59484278

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档