I exported a standard, pre-trained PyTorch model from Python with the following code:
import torch
import torchvision
model = torchvision.models.resnext50_32x4d(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("traced_resnext50_32x4d_model.pt")
I am now trying to load this model in LibTorch via the function Module torch::jit::load(std::istream &in, c10::optional<c10::Device> device = c10::nullopt).
It takes a std::istream, and I have the .pt model available as a plain const char* buffer with a size. Since std::strstreambuf is deprecated, I wrote a custom streambuf:
class MemReader : public std::streambuf {
public:
MemReader(const char* data, size_t size);
private:
int_type underflow();
int_type uflow();
int_type pbackfail(int_type ch);
std::streamsize showmanyc();
const char* const begin_;
const char* const end_;
const char* current_;
};
MemReader::MemReader(const char* data, size_t size) :
begin_(data),
end_(data + size),
current_(data)
{}
MemReader::int_type MemReader::underflow() {
if (current_ == end_) {
return traits_type::eof();
}
return traits_type::to_int_type(*current_);
}
MemReader::int_type MemReader::uflow() {
if (current_ == end_) {
return traits_type::eof();
}
return traits_type::to_int_type(*current_++);
}
MemReader::int_type MemReader::pbackfail(int_type ch) {
if (current_ == begin_ || (ch != traits_type::eof() && ch != current_[-1])) {
return traits_type::eof();
}
return traits_type::to_int_type(*--current_);
}
std::streamsize MemReader::showmanyc() {
return end_ - current_;
}我验证了.pt模型文件和包装在std::istream(&mr)中的MemReader mr都包含完全相同的数据。
However, when I load the model with this code:
utils::MemReader mr(modelScript, modelScriptSize);
std::istream is(&mr);
mod_ = torch::jit::load(is, device_);
I get this error:
istream reader failed: checking archive.
Exception raised from validate at /tmp/pytorch/pytorch/caffe2/serialize/istream_adapter.cc:32 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6c (0x7fdce0d4c7ac in /usr/local/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xfa (0x7fdce0d18866 in /usr/local/lib/libc10.so)
frame #2: caffe2::serialize::IStreamAdapter::validate(char const*) const + 0x17b (0x7fdce3635beb in /usr/local/lib/libtorch_cpu.so)
frame #3: caffe2::serialize::IStreamAdapter::read(unsigned long, void*, unsigned long, char const*) const + 0x41 (0x7fdce3635d21 in /usr/local/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x3f9c02b (0x7fdce4d2402b in /usr/local/lib/libtorch_cpu.so)
frame #5: torch::jit::load(std::shared_ptr<caffe2::serialize::ReadAdapterInterface>, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0x6c (0x7fdce4d20a5c in /usr/local/lib/libtorch_cpu.so)
frame #6: torch::jit::load(std::istream&, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0xc2 (0x7fdce4d22aa2 in /usr/local/lib/libtorch_cpu.so)
frame #7: torch::jit::load(std::istream&, c10::optional<c10::Device>) + 0x6a (0x7fdce4d22b8a in /usr/local/lib/libtorch_cpu.so)
I am aware that the model must be traced with the same PyTorch version as the LibTorch version used to load it. However, when I load the model via a file path instead:
mod_ = torch::jit::load("/data/bin/traced_resnext50_32x4d_model.pt", device_);
it works!
Does anyone know what is going on here?
Posted on 2022-02-16 15:29:48
I have now replaced my MemReader with the deprecated std::strstreambuf, and it works. Although both produce identical output when reading via get(char), there must be some difference between them.
I don't know enough C++ to say what exactly the problem is here.
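The likely explanation: a TorchScript .pt file is a zip archive, and torch::jit::load seeks within the stream while reading it (the validate frame at istream_adapter.cc in the backtrace fires when a read or seek leaves the stream in a bad state). std::strstreambuf implements seekoff/seekpos; MemReader inherits std::streambuf's defaults, which fail every seek, so the first seekg poisons the stream. To avoid the deprecated strstreambuf, here is a sketch of a read-only memory streambuf with seek support, assuming the same const char*/size inputs as MemReader:

```cpp
#include <cstddef>
#include <istream>
#include <streambuf>

// Read-only memory streambuf with seek support. Reading is delegated to
// std::streambuf's defaults by pre-setting the get area, so no
// underflow/uflow/pbackfail overrides are needed.
class SeekableMemReader : public std::streambuf {
public:
    SeekableMemReader(const char* data, size_t size) {
        char* p = const_cast<char*>(data);  // never written through
        setg(p, p, p + size);
    }

protected:
    // Relative seek from beg/cur/end; called by seekg(off, dir) and tellg().
    pos_type seekoff(off_type off, std::ios_base::seekdir dir,
                     std::ios_base::openmode which) override {
        if (!(which & std::ios_base::in))
            return pos_type(off_type(-1));
        char* target = nullptr;
        switch (dir) {
            case std::ios_base::beg: target = eback() + off; break;
            case std::ios_base::cur: target = gptr()  + off; break;
            case std::ios_base::end: target = egptr() + off; break;
            default:                 return pos_type(off_type(-1));
        }
        if (target < eback() || target > egptr())
            return pos_type(off_type(-1));
        setg(eback(), target, egptr());     // move the get pointer
        return pos_type(target - eback());
    }

    // Absolute seek; called by seekg(pos).
    pos_type seekpos(pos_type pos, std::ios_base::openmode which) override {
        return seekoff(off_type(pos), std::ios_base::beg, which);
    }
};
```

With this, SeekableMemReader buf(modelScript, modelScriptSize); std::istream is(&buf); should behave like the strstreambuf version, since strstreambuf works precisely because it implements the same seek protocol.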
https://stackoverflow.com/questions/71083901