Converting a Detectron2 model to TorchScript

Stack Overflow user
Asked 2022-09-06 08:50:15
Answers: 1 · Views: 302 · Followers: 0 · Votes: 0

I want to convert the detectron2 model 'COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml' to TorchScript. I used torc, and my code is given below.

```python
import os

import cv2
import numpy as np
import torch
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.modeling import build_model
from detectron2.export.flatten import TracingAdapter

ModelPath = '/home/jayasanka/working_files/create_torchsript/model.pt'
with open('savepic.npy', 'rb') as f:
    image = np.load(f)

#-------------------------------------------------------------------------------------

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # number of object classes (detectron2 does not count background)
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, ModelPath)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.60  # set the testing threshold for this model

predictor = DefaultPredictor(cfg)
```

I used TracingAdapter and the trace function, and I understand the concept behind them.
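Background on that concept: `torch.jit.trace` expects the traced module to take and return tensors (or tuples of them), whereas detectron2 models take a list of dicts such as `[{"image": tensor}]`. `TracingAdapter` bridges the gap by flattening the inputs into a tuple and rebuilding the original structure before it calls `inference_func(model, ...)`. It also flattens the outputs. The pure-Python sketch below mimics only the input half of that idea. `flatten_inputs`, `unflatten_inputs` and `FlatAdapter` are hypothetical names and not detectron2 API. Strings stand in for tensors, so the sketch runs without torch.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical helpers (NOT detectron2's API) mimicking the input side of
# TracingAdapter. Strings stand in for tensors so this runs without torch.


def flatten_inputs(inputs: List[Dict[str, Any]]) -> Tuple[Tuple[Any, ...], List[List[str]]]:
    """Flatten a list of dicts into a tuple of values plus a key schema."""
    values: List[Any] = []
    schema: List[List[str]] = []
    for item in inputs:
        keys = sorted(item)  # fixed key order so both directions agree
        schema.append(keys)
        values.extend(item[k] for k in keys)
    return tuple(values), schema


def unflatten_inputs(values: Tuple[Any, ...], schema: List[List[str]]) -> List[Dict[str, Any]]:
    """Rebuild the list of dicts from the flat tuple and the schema."""
    rebuilt: List[Dict[str, Any]] = []
    pos = 0
    for keys in schema:
        rebuilt.append({k: values[pos + i] for i, k in enumerate(keys)})
        pos += len(keys)
    return rebuilt


class FlatAdapter:
    """Takes flat positional values and calls inference_func(model, original_inputs)."""

    def __init__(self, model: Any, inputs: List[Dict[str, Any]], inference_func: Callable) -> None:
        self.model = model
        self.flat_inputs, self.schema = flatten_inputs(inputs)
        self.inference_func = inference_func

    def __call__(self, *flat_values: Any) -> Any:
        original_format = unflatten_inputs(flat_values, self.schema)
        return self.inference_func(self.model, original_format)


if __name__ == "__main__":
    # Stand-in "model": takes a batch (list of dicts) and upper-cases each image.
    def toy_model(batch: List[Dict[str, Any]]) -> List[str]:
        return [d["image"].upper() for d in batch]

    adapter = FlatAdapter(toy_model, [{"image": "img_a"}], lambda m, inputs: m(inputs))
    print(adapter.flat_inputs)            # ('img_a',)
    print(adapter(*adapter.flat_inputs))  # ['IMG_A']
```

Unlike this sketch, the real adapter requires every flattened input to be a tensor unless told otherwise. It also hands the flat tuple to `torch.jit.trace` as the example inputs.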

```python
# im = cv2.imread(image)
im = torch.tensor(image)

def inference_func(model, image):
    inputs = [{"image": image}]
    return model.inference(inputs, do_postprocess=False)[0]

wrapper = TracingAdapter(predictor, im, inference_func)
wrapper.eval()
traced_script_module = torch.jit.trace(wrapper, (im,))
traced_script_module.save("torchscript.pt")
```

It gives the following error:

```text
Traceback (most recent call last):
  File "script.py", line 49, in <module>
    traced_script_module= torch.jit.trace(wrapper, (im,))
  File "/home/jayasanka/anaconda3/envs/vha/lib/python3.7/site-packages/torch/jit/_trace.py", line 744, in trace
    _module_class,
  File "/home/jayasanka/anaconda3/envs/vha/lib/python3.7/site-packages/torch/jit/_trace.py", line 959, in trace_module
    argument_names,
  File "/home/jayasanka/anaconda3/envs/vha/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jayasanka/anaconda3/envs/vha/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/jayasanka/anaconda3/envs/vha/lib/python3.7/site-packages/detectron2/export/flatten.py", line 294, in forward
    outputs = self.inference_func(self.model, *inputs_orig_format)
  File "script.py", line 44, in inference_func
    return model.inference(inputs, do_postprocess=False)[0]
  File "/home/jayasanka/anaconda3/envs/vha/lib/python3.7/site-packages/yacs/config.py", line 141, in __getattr__
    raise AttributeError(name)
AttributeError: inference
```

Can you help me figure this out? Is there another, easier way to do this?
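Background on the input format: detectron2 models expect each `"image"` entry as a float tensor in (C, H, W) order. `DefaultPredictor` itself converts its HWC NumPy input with `image.astype("float32").transpose(2, 0, 1)` before calling the model. The question's second snippet instead passes `torch.tensor(image)` unchanged, which keeps the array presumably in HWC order, as loaded from OpenCV. Below is a minimal NumPy check of that conversion. The all-zeros 480x640x3 array is an assumed stand-in for the contents of `savepic.npy`.

```python
import numpy as np

# Assumed stand-in for the array in savepic.npy: an OpenCV-style HWC uint8 image.
image = np.zeros((480, 640, 3), dtype=np.uint8)

# Same conversion DefaultPredictor applies before torch.as_tensor:
# cast to float32, then move channels first (HWC -> CHW).
chw = image.astype("float32").transpose(2, 0, 1)

print(image.shape, image.dtype)  # (480, 640, 3) uint8
print(chw.shape, chw.dtype)      # (3, 480, 640) float32
```

`DefaultPredictor` also resizes the image according to the config and flips channels when `cfg.INPUT.FORMAT` is `"RGB"`. Calling `model.inference` directly skips those steps, so a traced model needs the same preprocessing applied outside it.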


1 Answer

Stack Overflow user

Posted 2022-09-26 03:52:34

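Change it to:

```python
import numpy as np
import torch
from detectron2.export.flatten import TracingAdapter

def inference(model, inputs):
    # use do_postprocess=False so it returns ROI mask
    inst = model.inference(inputs, do_postprocess=False)[0]
    return [{"instances": inst}]

assert isinstance(image, np.ndarray)  # HWC array loaded from savepic.npy
image_tensor = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))  # HWC -> CHW float32
# pass the underlying nn.Module (predictor.model), not the DefaultPredictor wrapper
wrapper = TracingAdapter(predictor.model, inputs=[{"image": image_tensor}], inference_func=inference)
traced_script_module = torch.jit.trace(wrapper, (image_tensor,))
traced_script_module.save("torchscript.pt")
```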
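Why the model argument matters: `TracingAdapter` passes its first argument unchanged to `inference_func` as `model`, so that object must have an `inference` method. In the question's traceback, the lookup ends in yacs's `CfgNode.__getattr__`, which means `inference` was looked up on a config object. `DefaultPredictor` has no `inference` method either. It keeps the network in `predictor.model`. The minimal sketch below shows the difference without needing detectron2. It uses hypothetical stand-in classes (`ToyNetwork`, `ToyPredictor`) that are not detectron2 classes.

```python
# Hypothetical stand-ins (NOT detectron2 classes) that mirror the relevant shape:
# a predictor wrapper that stores its network as `.model`, and a network that
# exposes `.inference(...)`.


class ToyNetwork:
    """Plays the role of the GeneralizedRCNN module."""

    def inference(self, batched_inputs, do_postprocess=True):
        return [f"instances for {d['image']}" for d in batched_inputs]


class ToyPredictor:
    """Plays the role of DefaultPredictor: wraps the network, has no .inference."""

    def __init__(self):
        self.model = ToyNetwork()

    def __call__(self, image):
        return self.model.inference([{"image": image}])[0]


def inference(model, inputs):
    # same shape as the inference function in the answer
    inst = model.inference(inputs, do_postprocess=False)[0]
    return [{"instances": inst}]


predictor = ToyPredictor()
batch = [{"image": "img0"}]

try:
    inference(predictor, batch)  # what handing the wrapper to the adapter amounts to
except AttributeError as err:
    print("wrapper fails:", err)

print(inference(predictor.model, batch))  # [{'instances': 'instances for img0'}]
```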
Votes: 0
Original page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/73619217