
Unable to convert a PyTorch model to TorchScript format

Stack Overflow user
Asked 2022-04-25 17:04:41
Answers: 1 · Views: 922 · Followers: 0 · Votes: 0

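I load a pretrained PyTorch model file, and when I try to convert it with torch.jit.script I get the error below. When I run a built-in pretrained model from pytorch.org, it works perfectly well (for example, (link to sample code)). For a custom-built pretrained model (Git repo containing the pretrained model weights) (pretrained weights), however, it throws an error.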

encoder = enCoder()
encoder = torch.nn.DataParallel(encoder)
encoder.load_state_dict(weights['state_dict'])
encoder.eval()

torchscript_model = torch.jit.script(encoder)

# Error
---------------------------------------------------------------------------
NotSupportedError                         Traceback (most recent call last)
[<ipython-input-30-1d9f30e14902>](https://localhost:8080/#) in <module>()
      1 # torch.quantization.convert(encoder, inplace=True)
      2 
----> 3 torchscript_model = torch.jit.script(encoder)

8 frames
[/usr/local/lib/python3.7/dist-packages/torch/jit/_script.py](https://localhost:8080/#) in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1256         obj = call_prepare_scriptable_func(obj)
   1257         return torch.jit._recursive.create_script_module(
-> 1258             obj, torch.jit._recursive.infer_methods_to_compile
   1259         )
   1260 

[/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py](https://localhost:8080/#) in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    449     if not is_tracing:
    450         AttributeTypeIsSupportedChecker().check(nn_module)
--> 451     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    452 
    453 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

[/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py](https://localhost:8080/#) in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    461     """
    462     cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
--> 463     method_stubs = stubs_fn(nn_module)
    464     property_stubs = get_property_stubs(nn_module)
    465     hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)

[/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py](https://localhost:8080/#) in infer_methods_to_compile(nn_module)
    730     stubs = []
    731     for method in uniqued_methods:
--> 732         stubs.append(make_stub_from_method(nn_module, method))
    733     return overload_stubs + stubs
    734 

[/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py](https://localhost:8080/#) in make_stub_from_method(nn_module, method_name)
     64     # In this case, the actual function object will have the name `_forward`,
     65     # even though we requested a stub for `forward`.
---> 66     return make_stub(func, method_name)
     67 
     68 

[/usr/local/lib/python3.7/dist-packages/torch/jit/_recursive.py](https://localhost:8080/#) in make_stub(func, name)
     49 def make_stub(func, name):
     50     rcb = _jit_internal.createResolutionCallbackFromClosure(func)
---> 51     ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
     52     return ScriptMethodStub(rcb, ast, func)
     53 

[/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py](https://localhost:8080/#) in get_jit_def(fn, def_name, self_name, is_classmethod)
    262         pdt_arg_types = type_trace_db.get_args_types(qualname)
    263 
--> 264     return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
    265 
    266 # TODO: more robust handling of recognizing ignore context manager

[/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py](https://localhost:8080/#) in build_def(ctx, py_def, type_line, def_name, self_name, pdt_arg_types)
    300                        py_def.col_offset + len("def"))
    301 
--> 302     param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
    303     return_type = None
    304     if getattr(py_def, 'returns', None) is not None:

[/usr/local/lib/python3.7/dist-packages/torch/jit/frontend.py](https://localhost:8080/#) in build_param_list(ctx, py_args, self_name, pdt_arg_types)
    324         expr = py_args.kwarg
    325         ctx_range = ctx.make_range(expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg))
--> 326         raise NotSupportedError(ctx_range, _vararg_kwarg_err)
    327     if py_args.vararg is not None:
    328         expr = py_args.vararg

NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py", line 147
    def forward(self, *inputs, **kwargs):
                                ~~~~~~~ <--- HERE
        with torch.autograd.profiler.record_function("DataParallel.forward"):
            if not self.device_ids:
### Versions

Collecting environment information...
PyTorch version: 1.10.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.12.0
Libc version: glibc-2.26

Python version: 3.7.13 (default, Mar 16 2022, 17:37:17)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: False
CUDA runtime version: 11.1.105
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.10.0+cu111
[pip3] torchaudio==0.10.0+cu111
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.11.0
[pip3] torchvision==0.11.1+cu111
[conda] Could not collect

Any help is greatly appreciated.


1 Answer

Stack Overflow user

Answered 2022-08-05 02:02:32

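torch.jit.script creates a ScriptFunction (a function with a graph) by parsing the Python source code of module.forward(). If your module uses syntax that this parser does not support, scripting fails; this is especially common with objects that lack static types. In the traceback above, the unsupported construct is the `*inputs, **kwargs` signature of DataParallel.forward.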
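As a minimal sketch of the failure described above, the following reproduces it on a toy model and shows one possible workaround: scripting the inner model that DataParallel stores in its `.module` attribute, whose forward has a fixed signature. `TinyEncoder` is a hypothetical stand-in for the question's `enCoder` class, which is not shown in the post, so whether the real model scripts cleanly once unwrapped is an assumption.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for the question's enCoder class."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fixed, typed signature: this is something the scripting parser accepts.
        return torch.relu(self.fc(x))


wrapped = nn.DataParallel(TinyEncoder())
wrapped.eval()

# Scripting the wrapper fails because DataParallel.forward is declared
# as forward(self, *inputs, **kwargs), which TorchScript cannot compile.
try:
    torch.jit.script(wrapped)
    wrapper_scripted = True
except Exception:
    wrapper_scripted = False

# The original model is kept in wrapped.module; script that instead.
scripted = torch.jit.script(wrapped.module)

# Sanity check: the scripted model matches the eager inner model.
x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(scripted(x), wrapped.module(x))
```

In the question's code this would correspond to loading the state dict into the DataParallel wrapper as before and then calling `torch.jit.script(encoder.module)` rather than `torch.jit.script(encoder)`.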

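Using torch.jit.trace avoids this kind of problem. It builds the graph while the ops are being called (on the C++ side). Because it does not parse the Python source, it will not fail on unsupported syntax, but it cannot handle if-else branches. If you have a branch, you should trace the model on every iteration, which results in 2 forward passes and 1 backward pass in every training step. With a branch-less model, you can reuse the traced ScriptFunction.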
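The branch limitation can be illustrated with a small sketch. `Branchy` is a hypothetical module written for this example, not part of the question's model. Tracing records only the path taken for the example input, so the traced module replays that path even when a later input would take the other branch; the scripted version keeps both branches.

```python
import torch
import torch.nn as nn


class Branchy(nn.Module):
    """Hypothetical module with a data-dependent branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Which path runs depends on the values in x, not just its shape.
        if x.sum() > 0:
            return x + 1
        return x - 1


model = Branchy()
pos = torch.ones(3)    # takes the x + 1 path
neg = -torch.ones(3)   # takes the x - 1 path

# Tracing runs the model once on `pos` and records only the ops that ran,
# so the x + 1 path is baked in. PyTorch emits a TracerWarning here.
traced = torch.jit.trace(model, pos)

# Scripting compiles the source of forward(), so both branches survive.
scripted = torch.jit.script(model)

eager_neg = model(neg)        # x - 1 path
traced_neg = traced(neg)      # replays the recorded x + 1 path
scripted_neg = scripted(neg)  # x - 1 path, same as eager
```

For a model without data-dependent control flow, the traced module gives the same results as the eager one on any input of a compatible shape, which is why it can be reused.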

Votes: 1
Original page content provided by Stack Overflow.
Original link: https://stackoverflow.com/questions/72003175
