首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >torchscript与C++接口中的错误输入形状误差

torchscript与C++接口中的错误输入形状误差
EN

Stack Overflow用户
提问于 2021-12-28 13:22:41
回答 1查看 136关注 0票数 0

我正在尝试通过接口lib呼机和OpenCV来预测使用Yolov5模型的类。我使用的权重是yolov5s.pt。源代码是

代码语言:javascript
复制
cv::Mat image = file->input_image(); // read image and resize into 640x640

auto tensor = torch::from_blob(image.data, {image.rows,image.cols,3}, torch::kFloat);
tensor = tensor.view({1,640,640,3});
std::cout << tensor.sizes() << std::endl;

try {
  auto output = model.forward({tensor}).toTensor();
  std::cout << output.sizes() << std::endl;
} catch (std::runtime_error & e) {
  std::cerr << "[X] Error: " << e.what() << std::endl;
  return;
}

错误消息

代码语言:javascript
复制
RuntimeError: Given groups=1, weight of size [32, 3, 6, 6], expected input[1, 640, 640, 3] to have 3 
channels, but got 640 channels instead

回溯

代码语言:javascript
复制
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/models/yolo.py", line 59, in forward
    model23 = self.model
    _0 = getattr(model23, "0")
    _25 = (_2).forward((_1).forward((_0).forward(x, ), ), )
                                     ~~~~~~~~~~~ <--- HERE
    _26 = (_4).forward((_3).forward(_25, ), )
    _27 = (_6).forward((_5).forward(_26, ), )
  File "code/__torch__/models/common.py", line 12, in forward
    act = self.act
    conv = self.conv
    _0 = (act).forward((conv).forward(x, ), )
                        ~~~~~~~~~~~~~ <--- HERE
    return _0
class C3(Module):
  File "code/__torch__/torch/nn/modules/conv.py", line 12, in forward
    bias = self.bias
    weight = self.weight
    x0 = torch._convolution(x, weight, bias, [2, 2], [2, 2], [1, 1], False, [0, 0], 1, False, False, True, True)
         ~~~~~~~~~~~~~~~~~~ <--- HERE
    return x0
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-12-28 14:08:21

解决办法很简单。我觉得太尴尬了,但你不应该。

这是解决办法

代码语言:javascript
复制
// forgot to add these both lines
// the yolov5 is expects [BATCH, CHANNEL, WIDTH, HEIGHT]
tensor = tensor.permute({2,0,1});
tensor = tensor = tensor.unsqueeze(0);

std::cout << tensor.sizes() << std::endl;
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70507682

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档