首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >多个参数值

多个参数值
EN

Stack Overflow用户
提问于 2020-12-10 14:49:39
回答 1查看 122关注 0票数 0

我试图用pysyft引用来转换这个传递它的代码

就像这样:

代码语言:javascript
复制
class SyNet(sy.Module):
  def __init__(self,embedding_size, num_numerical_cols, output_size, layers, p ,torch_ref):
    super(SyNet, self ).__init__(  embedding_size, num_numerical_cols , output_size , layers , p=0.4  ,torch_ref=torch_ref  )
    self.all_embeddings=self.torch_ref.nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in embedding_size])
    self.embedding_dropout=self.torch_ref.nn.Dropout(p)
    self.batch_norm_num=self.torch_ref.nn.BatchNorm1d(num_numerical_cols)

    all_layers= []
    num_categorical_cols = sum((nf for ni, nf in embedding_size))
    input_size = num_categorical_cols + num_numerical_cols

    for i in layers:
      all_layers.append(self.torch_ref.nn.Linear(input_size,i))
      all_layers.append(self.torch_ref.nn.ReLU(inplace=True))
      all_layers.append(self.torch_ref.nn.BatchNorm1d(i))
      all_layers.append(self.torch_ref.nn.Dropout(p))
      input_size = i

    all_layers.append(self.torch_ref.nn.Linear(layers[-1], output_size))

    self.layers = self.torch_ref.nn.Sequential(*all_layers)

  def forward(self, x_categorical, x_numerical):
    embeddings= []
    for i,e in enumerate(self.all_embeddings):
      embeddings.append(e(x_categorical[:,i]))

    x_numerical = self.batch_norm_num(x_numerical)
    x = self.torch_ref.cat([x, x_numerical], 1)
    x = self.layers(x)
    return x

但是当我试图创建模型的一个实例时

代码语言:javascript
复制
model = SyNet( categorical_embedding_sizes, numerical_data.shape[1], 2, [200,100,50], p=0.4 ,torch_ref= th)

我得到了一个TypeError

TypeError:参数'torch_ref'的多个值

我试图更改参数的顺序,但在位置参数方面出现了错误。你能帮我吗,我在课程和功能方面不是很有经验(oop)

提前谢谢你!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-12-10 15:04:00

查看PySyft源代码 for Module。类父类的构造函数只接受一个参数:torch_ref

因此,您应该使用以下方法调用超级构造函数:

代码语言:javascript
复制
super(SyNet, self).__init__(torch_ref=torch_ref) # line 3

从调用中删除除torch_ref以外的所有参数。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65236793

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档