首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如果我希望PyTorch模块可以加载,我应该如何保存它的模型

如果我希望PyTorch模块可以加载,我应该如何保存它的模型
EN

Stack Overflow用户
提问于 2017-08-29 01:53:27
回答 1查看 4.6K关注 0票数 6

我用PyTorch训练了一个简单的分类模型,然后用opencv3.3加载它,但是它抛出异常并说

/home/ramsus/Qt/3rdLibs/opencv/modules/dnn/src/torch/torch_importer.cpp,OpenCV错误:函数/特性未在readObject中实现(不支持Lua类型),文件第797行OpenCV错误:(-213)函数readObject中不支持的Lua类型

模型定义

代码语言:javascript
复制
class conv_block(nn.Module):
    def __init__(self, in_filter, out_filter, kernel):
        super(conv_block, self).__init__()

        self.conv1 = nn.Conv2d(in_filter, out_filter, kernel, 1, (kernel - 1)//2)
        self.batchnorm = nn.BatchNorm2d(out_filter)
        self.maxpool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.batchnorm(x)
        x = F.relu(x)
        x = self.maxpool(x)

        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = conv_block(3, 6, 3)
        self.conv2 = conv_block(6, 16, 3)
        self.fc1 = nn.Linear(16 * 8 * 8, 120)
        self.bn1 = nn.BatchNorm1d(120)
        self.fc2 = nn.Linear(120, 84)
        self.bn2 = nn.BatchNorm1d(84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.fc3(x)
        return x

该模型只使用Conv2d、ReLU、BatchNorm2d、MaxPool2d和线性层,每一层都得到opencv3.3的支持。

我用state_dict保存它

代码语言:javascript
复制
torch.save(net.state_dict(), 'cifar10_model')

通过c++加载它

代码语言:javascript
复制
std::string const model_file("/home/some_folder/cifar10_model");

std::cout<<"read net from torch"<<std::endl;
dnn::Net net = dnn::readNetFromTorch(model_file);

我想我是用错误的方式保存模型的,为了使用PyTorch加载OpenCV,保存模型的正确方法是什么?谢谢

编辑:

我使用另一种方式保存模型,但也不能加载它。

代码语言:javascript
复制
torch.save(net, 'cifar10_model.net')

这是虫子吗?还是我做错了什么?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-09-26 02:48:42

我找到了答案,opencv3.3不支持PyTorch (https://github.com/pytorch/pytorch),但支持py手电筒(https://github.com/hughperkins/pytorch),这是一个很大的惊喜,我从来不知道有另一个版本的py手电筒存在(看起来像一个死的项目,很长时间没有更新),我希望他们能提到他们支持的维基。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45929573

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档