首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将PyTorch模型转换为TorchScript时出错

将PyTorch模型转换为TorchScript时出错
EN

Stack Overflow用户
提问于 2018-12-18 01:23:34
回答 1查看 1.5K关注 0票数 2

我在试着跟着PyTorch guide to load models in C++走。

下面的示例代码可以工作:

代码语言:javascript
复制
import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

但是,当尝试其他网络时,例如squeezenet (或alexnet),我的代码失败:

代码语言:javascript
复制
sq = torchvision.models.squeezenet1_0(pretrained=True)
traced_script_module = torch.jit.trace(sq, example) 

>> traced_script_module = torch.jit.trace(sq, example)                                      
/home/fabio/.local/lib/python3.6/site-packages/torch/jit/__init__.py:642: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function.
 Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 785] (3.1476082801818848 vs. 3.945478677749634) and 999 other locations (100.00%)
  _check_trace([example_inputs], func, executor_options, module, check_tolerance, _force_outplace)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-12-18 01:25:35

我刚刚发现从torchvision.models加载的模型在默认情况下处于训练模式。AlexNet和SqueezeNet都有Dropout层,如果在训练模式下,则使推理不确定。只需更改为eval模式即可修复此问题:

代码语言:javascript
复制
sq = torchvision.models.squeezenet1_0(pretrained=True)
sq.eval()
traced_script_module = torch.jit.trace(sq, example) 
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53820175

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档