首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >通过std::istream加载Python在LibTorch中导出的LibTorch模型

通过std::istream加载Python在LibTorch中导出的LibTorch模型
EN

Stack Overflow用户
提问于 2022-02-11 17:06:06
回答 1查看 900关注 0票数 0

我用Python导出了一个标准的、经过预先训练的PyTorch模型,其代码如下:

代码语言:javascript
复制
import torch
import torchvision

model = torchvision.models.resnext50_32x4d(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("traced_resnext50_32x4d_model.pt")

我现在正试图通过Module torch::jit::load(std::istream &in, c10::optional<c10::Device> device = c10::nullopt)函数LibTorch加载这个模型。

它需要一个std::istream,我将.pt模型加载为具有一个大小的简单const char*缓冲区。由于不推荐使用strstreambuf,所以我需要使用自定义的streambuf:

代码语言:javascript
复制
class MemReader : public std::streambuf {
public:
    MemReader(const char* data, size_t size);
 
private:
    int_type underflow();
    int_type uflow();
    int_type pbackfail(int_type ch);
    std::streamsize showmanyc();
 
    const char* const begin_;
    const char* const end_;
    const char* current_;
};

MemReader::MemReader(const char* data, size_t size) : 
    begin_(data), 
    end_(data + size), 
    current_(data) 
{}
 
MemReader::int_type MemReader::underflow() {
    if (current_ == end_) {
        return traits_type::eof();
    }
    return traits_type::to_int_type(*current_);
}
 
MemReader::int_type MemReader::uflow() {
    if (current_ == end_) {
        return traits_type::eof();
    }
    return traits_type::to_int_type(*current_++);
}
 
MemReader::int_type MemReader::pbackfail(int_type ch) {
    if (current_ == begin_ || (ch != traits_type::eof() && ch != current_[-1])) {
        return traits_type::eof();
    }
    return traits_type::to_int_type(*--current_);
}
 
std::streamsize MemReader::showmanyc() {
    return end_ - current_;
}

我验证了.pt模型文件和包装在std::istream(&mr)中的MemReader mr都包含完全相同的数据

但是,当使用此代码加载模型时:

代码语言:javascript
复制
utils::MemReader mr(modelScript, modelScriptSize);
std::istream is(&mr);
mod_ = torch::jit::load(is, device_);

我知道这个错误:

代码语言:javascript
复制
istream reader failed: checking archive.
Exception raised from validate at /tmp/pytorch/pytorch/caffe2/serialize/istream_adapter.cc:32 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6c (0x7fdce0d4c7ac in /usr/local/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xfa (0x7fdce0d18866 in /usr/local/lib/libc10.so)
frame #2: caffe2::serialize::IStreamAdapter::validate(char const*) const + 0x17b (0x7fdce3635beb in /usr/local/lib/libtorch_cpu.so)
frame #3: caffe2::serialize::IStreamAdapter::read(unsigned long, void*, unsigned long, char const*) const + 0x41 (0x7fdce3635d21 in /usr/local/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x3f9c02b (0x7fdce4d2402b in /usr/local/lib/libtorch_cpu.so)
frame #5: torch::jit::load(std::shared_ptr<caffe2::serialize::ReadAdapterInterface>, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0x6c (0x7fdce4d20a5c in /usr/local/lib/libtorch_cpu.so)
frame #6: torch::jit::load(std::istream&, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0xc2 (0x7fdce4d22aa2 in /usr/local/lib/libtorch_cpu.so)
frame #7: torch::jit::load(std::istream&, c10::optional<c10::Device>) + 0x6a (0x7fdce4d22b8a in /usr/local/lib/libtorch_cpu.so)

我知道您需要使用与您的PyTorch版本相同的LibTorch版本来跟踪模型。但是,当我通过文件路径加载模型时:

代码语言:javascript
复制
mod_ = torch::jit::load("/data/bin/traced_resnext50_32x4d_model.pt", device_);

它起作用了!

有人知道这上面是什么吗?

EN

回答 1

Stack Overflow用户

发布于 2022-02-16 15:29:48

我现在用不推荐的std::strstreambuf替换了我的std::strstreambuf,现在它开始工作了。虽然它们在通过get(char)读取时产生相同的输出,但它们之间肯定有一些不同。

我不知道足够的C++,说出到底是什么问题在这里。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71083901

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档