To make trainable hyperparameters easier to manage, I am looking for a way to create a class Hyperparameter that acts both as an nn.Parameter and as an nn.Module. In particular, I want to use Hyperparameter objects as nn.Parameters (e.g., for tensor operations) while still having access to the interface that nn.Module provides, such as storing the objects alongside other modules in an nn.ModuleDict, or using methods like zero_grad() and parameters().
I tried to achieve this through multiple inheritance, but it does not quite work:
import torch


class Hyperparameter(torch.nn.Parameter, torch.nn.Module):
    def __new__(cls, tensor, name):
        return torch.nn.Parameter.__new__(cls, data=tensor)

    def __init__(self, tensor, name):
        torch.nn.Parameter.__init__(self)
        torch.nn.Module.__init__(self)
        self.register_parameter(name, self)


hp1 = Hyperparameter(torch.ones(5), "test1")
hp2 = Hyperparameter(torch.ones(5), "test2")

# Examples of what I want to do, which already work
tmp = hp1 * hp2
hp_dict = torch.nn.ModuleDict({"hp1": hp1, "hp2": hp2})

# What does not work with this solution
hp_dict.to(torch.device("cpu"))
# KeyError: "attribute 'data' already exists"

This works for everything I described (the objects can be added to a ModuleDict, algebraic operations work, ...), but calling to() throws an error. I suspect something no longer behaves the way nn.Module expects, but I cannot figure out what it is.
Edit: here is the full stack trace, as requested:
Traceback (most recent call last):
File "tmp.py", line 16, in <module>
hp_dict.to(torch.device("cpu"))
File ".myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 612, in to
return self._apply(convert)
File ".myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 359, in _apply
module._apply(fn)
File ".myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 384, in _apply
param.data = param_applied
File ".myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 796, in __setattr__
self.register_parameter(name, value)
File ".myenv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 316, in register_parameter
raise KeyError("attribute '{}' already exists".format(name))
KeyError: "attribute 'data' already exists"

Posted 2021-02-10 23:40:31
I don't see why you need nn.Module and nn.Parameter on the same object. (The to() call most likely fails because your Hyperparameter is itself a Module: when _apply assigns param.data, the assignment goes through Module.__setattr__, which tries to register the value as a parameter named 'data' and collides with the existing attribute, as your stack trace shows.) You can simply have an nn.Module that essentially is a parameter:
class Hyperparameter(torch.nn.Module):
    def __init__(self, tensor, name):
        super(Hyperparameter, self).__init__()
        self.register_parameter(name=name, param=torch.nn.Parameter(tensor))
        self._name = name

    def forward(self):
        return getattr(self, self._name)  # expose the parameter via forward

Now you can instantiate a Hyperparameter module:
my_hp = Hyperparameter(name='hyper', tensor=torch.arange(3.).requires_grad_(True))

Now you can access the hyperparameter hyper in several ways:
In [1]: list(my_hp.named_parameters())
Out[1]:
[('hyper', Parameter containing:
tensor([0., 1., 2.], requires_grad=True))]
In [2]: my_hp.hyper
Out[2]:
Parameter containing:
tensor([0., 1., 2.], requires_grad=True)
In [3]: my_hp()
Out[3]:
Parameter containing:
tensor([0., 1., 2.], requires_grad=True)
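With this design, the things listed in the question should also just work, since Hyperparameter is now an ordinary nn.Module holding an ordinary nn.Parameter. A minimal sketch, reusing the class above (the hp1/hp2/hp_dict names are only illustrative, not from the answer):

hp1 = Hyperparameter(torch.ones(5), "test1")
hp2 = Hyperparameter(torch.ones(5), "test2")

# Storing hyperparameters in an nn.ModuleDict and moving them now works:
hp_dict = torch.nn.ModuleDict({"hp1": hp1, "hp2": hp2})
hp_dict.to(torch.device("cpu"))  # no KeyError this time

# Tensor algebra goes through forward() (or the registered attribute):
tmp = hp1() * hp2()

# The usual nn.Module machinery is available as well:
print(list(hp_dict.parameters()))
hp_dict.zero_grad()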
https://stackoverflow.com/questions/66133229