首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >从InferenceSession导出的故障ONNX PyTorch ONNX模型

从InferenceSession导出的故障ONNX PyTorch ONNX模型
EN

Stack Overflow用户
提问于 2021-03-18 17:36:37
回答 1查看 818关注 0票数 1

我试图将自定义的PyTorch模型导出到ONNX以执行推理,但没有成功.这里的棘手之处在于,我试图使用基于脚本的导出程序,如示例这里中所示,以便从我的模型中调用一个函数。

我可以在没有任何抱怨的情况下导出模型,但是当尝试启动InferenceSession时,我会得到以下错误:

代码语言:javascript
复制
Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from ner.onnx failed:Type Error: Type parameter (T) bound to different types (tensor(int64) and tensor(float) in node (Concat_1260).

我试图找出造成这个问题的根本原因,它似乎是通过在以下函数中使用torch.matmul()而产生的(非常讨厌的原因是,我试图只使用pytorch操作符):

代码语言:javascript
复制
@torch.jit.script
def valid_sequence_output(sequence_output, valid_mask):
    X = torch.where(valid_mask.unsqueeze(-1) == 1, sequence_output, torch.zeros_like(sequence_output))
    bs, max_len, _ = X.shape

    tu = torch.unique(torch.nonzero(X)[:, :2], dim=0)
    batch_axis = tu[:, 0]
    rows_axis = tu[:, 1]

    a = torch.arange(bs).repeat(batch_axis.shape).reshape(batch_axis.shape[0], -1)
    a = torch.transpose(a, 0, 1)

    T = torch.cumsum(torch.where(batch_axis == a, torch.ones_like(a), torch.zeros_like(a)), dim=1) - 1
    cols_axis = T[batch_axis, torch.arange(batch_axis.shape[0])]

    A = torch.zeros((bs, max_len, max_len))
    A[(batch_axis, cols_axis, rows_axis)] = 1.0

    valid_output = torch.matmul(A, X)
    valid_attention_mask = torch.where(valid_output[:, :, 0] != 0, torch.ones_like(valid_mask),
                                       torch.zeros_like(valid_mask))
    return valid_output, valid_attention_mask

似乎不支持torch.matmul (根据文档),所以我尝试了一些解决方法(例如A.matmul(X)torch.baddbmm),但我仍然遇到了同样的问题.

任何关于如何纠正这种行为的建议都是很棒的:D感谢你的帮助!

EN

回答 1

Stack Overflow用户

发布于 2021-04-12 09:59:29

这指向一个模型转换问题。请打开一个针对火炬出口商特征的问题。一个类型(T)必须绑定到同一类型才能使模型有效,并且ORT基本上是在抱怨这一点。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66696275

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档