首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在Onnx中使用Pytorch模型的函数来获得输出而不是model.forward()函数

如何在Onnx中使用Pytorch模型的函数来获得输出而不是model.forward()函数
EN

Stack Overflow用户
提问于 2022-07-02 17:30:27
回答 1查看 318关注 0票数 0

博士:我如何使用model.whatever_function(input) model.forward(input) 而不是 onnxruntimemodel.forward(input)

我使用剪辑嵌入为我的图像和文本创建嵌入如下:

代码来自正式的git合并。

代码语言:javascript
复制
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

import clip
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model, preprocess = clip.load("RN50", device=device) # Load any model
model = model.eval() # Inference Only

img_size = model.visual.input_resolution
dummy_image = torch.randn(10, 3, img_size, img_size).to(device)
image_embedding = model.encode_image(dummy_image).to(device))

dummy_texts = clip.tokenize(["quick brown fox", "lorem ipsum"]).to(device)
model.encode_text(dummy_texts)

给出两个加载模型的[Batch, 1024]张量,效果都很好。

现在,我在Onnx中将我的模型量化为:

代码语言:javascript
复制
model.forward(dummy_image,dummy_texts) # Original CLIP result (1)

torch.onnx.export(model, (dummy_image, dummy_texts), "model.onnx", export_params=True,
  input_names=["IMAGE", "TEXT"],
  output_names=["LOGITS_PER_IMAGE", "LOGITS_PER_TEXT"],
  opset_version=14,
  dynamic_axes={
      "IMAGE": {
          0: "image_batch_size",
      },
      "TEXT": {
          0: "text_batch_size",
      },
      "LOGITS_PER_IMAGE": {
          0: "image_batch_size",
          1: "text_batch_size",
      },
      "LOGITS_PER_TEXT": {
          0: "text_batch_size",
          1: "image_batch_size",
      },
  }
)

模型就被保存了。

当我将模型测试为:

代码语言:javascript
复制
# Now run onnxruntime to verify
import onnxruntime as ort

ort_sess = ort.InferenceSession("model.onnx")
result=ort_sess.run(["LOGITS_PER_IMAGE", "LOGITS_PER_TEXT"], 
  {"IMAGE": dummy_image.numpy(), "TEXT": dummy_texts.numpy()})

它给出了一个长度为2的列表,每个图像和文本都有一个,result[0]具有[Batch,2]的形状。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-07-03 09:02:00

如果模块上的encode_image没有调用forward,那么在导出到Onnx之前,没有什么可以阻止您重写forward

代码语言:javascript
复制
>>> model.forward = model.encode_image
>>> torch.onnx.export(model, (dummy_image, dummy_texts), "model.onnx", ...))
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72841141

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档