首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用TorchScript类作为pytorch模块中的成员

使用TorchScript类作为pytorch模块中的成员
EN

Stack Overflow用户
提问于 2019-11-23 00:43:23
回答 1查看 536关注 0票数 2

我正在尝试让一些现有的pytorch模型支持TorchScript jit编译器,但是我遇到了非原语类型的成员问题。

这个小例子说明了这个问题:

代码语言:javascript
复制
import torch

@torch.jit.script
class Factory(object):
    def __init__(self):
        pass

    def create(self, x: float) -> torch.Tensor:
        return torch.tensor([x])

class Foo(torch.nn.Module):
    def __init__(self):
        super(Foo, self).__init__()
        self.factory: Factory = Factory()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.factory.create(0)

mod = torch.jit.script(Foo())

运行时,jit编译器会给出错误

代码语言:javascript
复制
RuntimeError:
module has no attribute 'factory':
at example.py:17:15
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.factory.create(0)
               ~~~~~~~~~~~~ <--- HERE

我已经测试了forward方法中的jit可以使用Factory类,但是当我将它存储为成员时,它不会承认它。为什么会这样呢?有没有办法让jit编译器将这种成员保存到编译后的模块中?

EN

回答 1

Stack Overflow用户

发布于 2020-12-01 23:37:08

这是PyTorch中的一个bug,在你发布你的问题:https://discuss.pytorch.org/t/jit-scripted-attributes-inside-module/60645https://github.com/pytorch/pytorch/issues/27495后不久就解决了。

更新PyTorch应该可以解决这个问题。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58998441

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档