首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >创建既是nn.Parameter又是nn.Module的类

创建既是nn.Parameter又是nn.Module的类
EN

Stack Overflow用户
提问于 2021-02-10 16:22:51
回答 1查看 164关注 0票数 0

为了便于管理可训练的超参数,我正在寻找一种方法来创建既充当nn.Parameter又充当nn.Module的类Hyperparameter。特别是,我希望将Hyperparameter对象用作nn.Parameter (例如,用于张量操作),但仍然可以访问nn.Module提供的接口,例如将对象与其他模块一起存储在nn.ModuleDict中,或者使用zero_grad()parameters()等方法。

我试图通过多重继承来实现这一点,但它并不是很有效:

代码语言:javascript
复制
import torch

class Hyperparameter(torch.nn.Parameter, torch.nn.Module):
    def __new__(cls, tensor, name):
        return torch.nn.Parameter.__new__(cls, data=tensor)

    def __init__(self, tensor, name):
        torch.nn.Parameter.__init__(self)
        torch.nn.Module.__init__(self)
        self.register_parameter(name, self)

hp1 = Hyperparameter(torch.ones(5), "test1")
hp2 = Hyperparameter(torch.ones(8), "test2")

# Examples of what I want to do, which already work
tmp = hp1 * hp2
hp_dict = torch.nn.ModuleDict({"hp1": hp1, "hp2": hp2})

# What does not work with this solution
hp_dict.to(torch.device("cpu"))
# KeyError: "attribute 'data' already exists"

这适用于我所描述的事情(可以添加到ModuleDict,可以执行代数操作,...),但是调用to()会抛出错误。我认为有些东西不再像nn.Module期望的那样,但我不明白它是什么。

编辑:以下是请求的完整堆栈跟踪:

代码语言:javascript
复制
Traceback (most recent call last):
  File "tmp.py", line 16, in <module>
    hp_dict.to(torch.device("cpu"))
  File ".myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 612, in to
    return self._apply(convert)
  File ".myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 359, in _apply
    module._apply(fn)
  File ".myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 384, in _apply
    param.data = param_applied
  File ".myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 796, in __setattr__
    self.register_parameter(name, value)
  File ".myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 316, in register_parameter
    raise KeyError("attribute '{}' already exists".format(name))
KeyError: "attribute 'data' already exists"
EN

回答 1

Stack Overflow用户

发布于 2021-02-10 23:40:31

我不明白为什么你需要在同一个对象上同时使用nn.Modulenn.Parameter。你可以有一个基本上就是参数的nn.Module

代码语言:javascript
复制
class Hyperparameter(torch.nn.Module):
    def __init__(self, tensor, name):
        super(Hyperparameter, self).__init__()
        self.register_parameter(name=name, param=nn.Parameter(tensor))
        self._name = name
        
    def forward(self):
      return getattr(self, self._mame)  # expose the parameter via forward

现在,您可以拥有一个Hyperparameter模块:

代码语言:javascript
复制
my_hp = Hyperparameter(name='hyper', data=torch.arange(3.).requires_grad_(True))

现在,您可以通过多种方式访问超参数hyper

代码语言:javascript
复制
In [1]: list(my_hp.named_parameters())
Out[1]:
[('hyper', Parameter containing:
  tensor([0., 1., 2.], requires_grad=True))]

In [2]: my_hp.hyper
Out[2]:
Parameter containing:
tensor([0., 1., 2.], requires_grad=True)

In [3]: my_hp()
Out[3]:
Parameter containing:
tensor([0., 1., 2.], requires_grad=True)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66133229

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档