首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在torchscript中使用自定义python对象

如何在torchscript中使用自定义python对象
EN

Stack Overflow用户
提问于 2020-06-09 17:17:15
回答 1查看 2.1K关注 0票数 2

我准备好将pytorch模块转换为ScriptModule,然后将其加载到c++,but中。我被这个错误This attribute exists on the Python module, but we failed to convert Python type: 'Vocab' to a TorchScript type阻塞了,Vocab是我定义的python对象。演示代码如下:

代码语言:javascript
复制
import torch
class Vocab(object):
    def __init__(self, name):
        self.name = name

    def show(self):
        print("dict:" + self.name)

class Model(torch.nn.Module):
    def __init__(self, ):
        super(Model, self).__init__()
        self.layers = torch.nn.Linear(2, 3)
        self.encoder = 4
        self.vocab = Vocab("vocab")

    def forward(self, x):
        name = self.vocab.name
        print("forward show encoder:" + str(self.encoder))
        print("vocab:" + name)
        enc_hidden = []
        step = len(x) // 2
        for i in range(step):
            enc_hidden.append((x[2*i] + x[2*i + 1])/2)
        enc_hidden = torch.stack(enc_hidden, 0)
        enc_hidden = self.__show(enc_hidden)
        return self.layers(enc_hidden)

    @torch.jit.export
    def __show(self, x):
        return x + 1

model = Model()
data = torch.randn(10, 2)
script_model = torch.jit.script(model)
print(script_model)
r1 = model(data)
print(r1)

错误消息:

代码语言:javascript
复制
Traceback (most recent call last):
  File "/mnt/d/python_projects/pytorch_deploy/model4.py", line 47, in <module>
    script_model = torch.jit.script(model)
  File "/mnt/d/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1261, in script
    return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
  File "/mnt/d/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 305, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/mnt/d/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 361, in create_script_module_impl
    create_methods_from_stubs(concrete_type, stubs)
  File "/mnt/d/anaconda3/lib/python3.6/site-packages/torch/jit/_recursive.py", line 279, in create_methods_from_stubs
    concrete_type._create_methods(defs, rcbs, defaults)
RuntimeError: 
  Module 'Model' has no attribute 'vocab' (This attribute exists on the Python module, but we failed to convert Python type: 'Vocab' to a TorchScript type.):
  File "/mnt/d/python_projects/pytorch_deploy/model4.py", line 26
  def forward(self, x):
    name = self.vocab.name
           ~~~~~~~~~~ <--- HERE
    print("forward show encoder:" + str(self.encoder))
    print("vocab:" + name)

那么我如何在torchscript中使用我自己的python对象呢?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-09 18:02:53

你必须像这样用torchscript.jit注释你的Vocab

代码语言:javascript
复制
@torch.jit.script
class Vocab(object):
    def __init__(self, name: str):
        self.name = name

    def show(self):
        print("dict:" + self.name)

还要注意规范name: str,因为torchscript也需要它来推断它的类型(PyTorch支持>=Python3.6类型注释,你也可以使用注释,但它不太清楚)。

请在那里查看Torchscript classesDefault Types以及其他相关的torchscript信息。

票数 7
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62279080

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档