为什么添加装饰符"@torch.jit.script“会导致错误,而我可以在该模块上调用torch.jit.script,例如,这会失败:
import torch
@torch.jit.script
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
return new_h, new_h
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.script(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)"C:\Users\Administrator\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\jit\__init__.py", line 1262, in script
raise RuntimeError("Type '{}' cannot be compiled since it inherits"
RuntimeError: Type '<class '__main__.MyCell'>' cannot be compiled since it inherits from nn.Module, pass an instance instead虽然以下代码运行良好:
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
return new_h, new_h
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.script(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)这个问题也出现在PyTorch论坛上。
发布于 2020-08-13 19:56:05
您错误的原因是这里,这个公告正是:
不支持继承或任何其他多态性策略,但从对象继承以指定新样式类除外。
此外,如上面所述:
TorchScript类支持是实验性的。目前,它最适合于简单的类记录类型(比如带有方法的NamedTuple )。
目前,它的目的是用于简单的Python类(请参阅我提供的链接中的其他要点)和函数,请参阅我提供的链接以获得更多信息。
您还可以检查源代码以更好地了解它的工作原理。
从表面上看,当您传递一个实例时,应该保留的所有attributes都是递归解析的(来源)。您可以遵循这个函数(非常注释,但是对于一个答案来说太长了,请参见这里),尽管我不知道为什么会这样(以及为什么它是这样设计的)(所以希望有torch.jit内部工作经验的人会更多地谈论它)。
https://stackoverflow.com/questions/63401585
复制相似问题