首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用pytorch-lightning进行简单预测的示例

使用pytorch-lightning进行简单预测的示例
EN

Stack Overflow用户
提问于 2020-05-03 06:04:45
回答 1查看 6.1K关注 0票数 3

我有一个现有的模型,在这个模型中,我加载了一些预先训练好的权重,然后在pytorch中进行预测(一次一个图像)。我正在尝试将它基本上转换为pytorch闪电模块,并对一些事情感到困惑。

因此,目前,我的模型__init__方法如下所示:

代码语言:javascript
复制
self._load_config_file(cfg_file)
# just creates the pytorch network
self.create_network()  

self.load_weights(weights_file)

self.cuda(device=0)  # assumes GPU and uses one. This is probably suboptimal
self.eval()  # prediction mode

我可以从lightning文档中收集到的信息,除了不执行cuda()调用之外,我几乎可以做同样的事情。所以就像这样:

代码语言:javascript
复制
self.create_network()

self.load_weights(weights_file)
self.freeze()  # prediction mode

所以,我的第一个问题是,这是否是使用闪电的正确方式?lightning如何知道它是否需要使用GPU?我猜这需要在某个地方指定。

现在,对于预测,我有以下设置:

代码语言:javascript
复制
def infer(frame):
    img = transform(frame)  # apply some transformation to the input
    img = torch.from_numpy(img).float().unsqueeze(0).cuda(device=0)
    with torch.no_grad():
        output = self.__call__(Variable(img)).data.cpu().numpy()
    return output

这就是让我感到困惑的地方。我需要重写哪些函数才能做出兼容闪电的预测?

而且,目前,输入是以numpy数组的形式出现的。这是不是可以通过lightning模块实现,或者总是需要使用某种数据加载器?

在某种程度上,我想扩展这个模型实现来做训练,所以我想确保我做的是正确的,但是虽然大多数例子都集中在训练模型上,但是在生产时对单个图像/数据点进行预测的简单例子可能会很有用。

我在带有cuda 10.1的GPU上使用0.7.5和pytorch 1.4.0

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-07-25 21:29:13

LightningModuletorch.nn.Module的子类,因此同一模型类既可用于推理,也可用于训练。因此,您可能应该在__init__外部调用cuda()eval()方法。

由于它只是一个幕后的nn.Module,一旦您加载了权重,就不需要覆盖任何方法来执行推理,只需调用模型实例即可。下面是一个你可以使用的玩具示例:

代码语言:javascript
复制
import torchvision.models as models
from pytorch_lightning.core import LightningModule

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.resnet = models.resnet18(pretrained=True, progress=False)
    
    def forward(self, x):
        return self.resnet(x)

model = MyModel().eval().cuda(device=0)

然后,要实际运行推理,您不需要方法,只需执行以下操作:

代码语言:javascript
复制
for frame in video:
    img = transform(frame)
    img = torch.from_numpy(img).float().unsqueeze(0).cuda(0)
    output = model(img).data.cpu().numpy()
    # Do something with the output

PyTorchLighting的主要好处是,您还可以通过在同一个类上实现training_step()configure_optimizers()train_dataloader()来使用该类进行训练。您可以在PyTorchLightning docs中找到一个简单的示例。

票数 8
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61566919

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档