首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Torch.jit.script(模块)与@torch.jit.script装饰器

Torch.jit.script(模块)与@torch.jit.script装饰器
EN

Stack Overflow用户
提问于 2020-08-13 18:57:55
回答 1查看 2.4K关注 0票数 2

为什么添加装饰符"@torch.jit.script“会导致错误,而我可以在该模块上调用torch.jit.script,例如,这会失败:

代码语言:javascript
复制
import torch
    
@torch.jit.script
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h
    
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.script(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
代码语言:javascript
复制
"C:\Users\Administrator\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\jit\__init__.py", line 1262, in script
    raise RuntimeError("Type '{}' cannot be compiled since it inherits"
RuntimeError: Type '<class '__main__.MyCell'>' cannot be compiled since it inherits from nn.Module, pass an instance instead

虽然以下代码运行良好:

代码语言:javascript
复制
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)
    
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h
    
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.script(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)

这个问题也出现在PyTorch论坛上。

EN

回答 1

Stack Overflow用户

发布于 2020-08-13 19:56:05

您错误的原因是这里,这个公告正是:

不支持继承或任何其他多态性策略,但从对象继承以指定新样式类除外。

此外,如上面所述:

TorchScript类支持是实验性的。目前,它最适合于简单的类记录类型(比如带有方法的NamedTuple )。

目前,它的目的是用于简单的Python类(请参阅我提供的链接中的其他要点)和函数,请参阅我提供的链接以获得更多信息。

您还可以检查源代码以更好地了解它的工作原理。

从表面上看,当您传递一个实例时,应该保留的所有attributes都是递归解析的(来源)。您可以遵循这个函数(非常注释,但是对于一个答案来说太长了,请参见这里),尽管我不知道为什么会这样(以及为什么它是这样设计的)(所以希望有torch.jit内部工作经验的人会更多地谈论它)。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63401585

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档