首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何加载.pth文件?

如何加载.pth文件?
EN

Stack Overflow用户
提问于 2021-05-23 08:56:29
回答 1查看 8.2K关注 0票数 0

我从这个存储库中得到了一个pytorch模型,我必须将它转换为tflite。下面是代码:

代码语言:javascript
复制
def get_torch_model(model_path):
    """
    Loads state-dict into model and creates an instance.
    """
    model= torch.load(model_path)
    return model
代码语言:javascript
复制
# Conversion
import torch
from torchvision import transforms

import onnx

import cv2
import numpy as np
import onnx
import tensorflow as tf
import torch
from PIL import Image

import torch.onnx

image, tf_lite_image, sample_input = get_sample_input("crop.jpg")
torch_model = get_torch_model("pose_resnet_152_256x256.pth")

ONNX_FILE = "./m_model.onnx"

在此之前一切都进展顺利。但是当我运行下面的单元格时:

代码语言:javascript
复制
torch.onnx.export(
        model=torch_model,
        args=sample_input,
        f=ONNX_FILE,
        verbose=False,
        export_params=True,
        do_constant_folding=False,  # fold constant values for optimization
        input_names=['input'],
        opset_version=10,
        output_names=['output']
)

onnx_model = onnx.load(ONNX_FILE)

onnx.checker.check_model(onnx_model)

完整的错误日志:

代码语言:javascript
复制
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-33-15df717ec276> in <module>
      8         input_names=['input'],
      9         opset_version=10,
---> 10         output_names=['output']
     11 )
     12 

~\anaconda3\envs\py36\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
    274                         do_constant_folding, example_outputs,
    275                         strip_doc_string, dynamic_axes, keep_initializers_as_inputs,
--> 276                         custom_opsets, enable_onnx_checker, use_external_data_format)
    277 
    278 

~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
     92             dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs,
     93             custom_opsets=custom_opsets, enable_onnx_checker=enable_onnx_checker,
---> 94             use_external_data_format=use_external_data_format)
     95 
     96 

~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference, use_new_jit_passes)
    677         _set_opset_version(opset_version)
    678         _set_operator_export_type(operator_export_type)
--> 679         with select_model_mode_for_export(model, training):
    680             val_keep_init_as_ip = _decide_keep_init_as_input(keep_initializers_as_inputs,
    681                                                              operator_export_type,

~\anaconda3\envs\py36\lib\contextlib.py in __enter__(self)
     79     def __enter__(self):
     80         try:
---> 81             return next(self.gen)
     82         except StopIteration:
     83             raise RuntimeError("generator didn't yield") from None

~\anaconda3\envs\py36\lib\site-packages\torch\onnx\utils.py in select_model_mode_for_export(model, mode)
     36 def select_model_mode_for_export(model, mode):
     37     if not isinstance(model, torch.jit.ScriptFunction):
---> 38         is_originally_training = model.training
     39 
     40         if mode is None:

AttributeError: 'collections.OrderedDict' object has no attribute 'training'

当我使用torch.onnx.export()时会发生此错误。

请告诉我这里出了什么问题。我没把重量装好吗?如果没有,我如何加载模型?我不知道类或架构细节,所以我如何使用model.load_state_dict()?

代码语言:javascript
复制
  [1]: https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
EN

回答 1

Stack Overflow用户

发布于 2021-05-23 10:46:15

pytorch中的.pth二进制文件不存储模型,而只存储其经过训练的权重。您需要import实现模型功能的class (派生class of torch.nn.Module)。一旦您有了这个功能,您就可以加载经过训练的权重,以获得要使用的模型的一个特定实例。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67657926

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档