Pytorch文档提供了一个应用MiDaS单目深度估计网络进行深度提取的concise way。但是,我应该如何修改它们的代码,以便在某些中间层提取网络表示呢?我知道我可以从github下载模型并修改forward函数来返回我想要的东西,但我对最简单的解决方案感兴趣,让外部代码保持原样。
我知道对模型类进行子类化并编写自己的转发函数,比如here,但我不知道如何在代码中访问该类。模型实例是用midas = torch.hub.load("intel-isl/MiDaS", model_type)直接创建的。也许使用前向钩子的例子会更简单。
发布于 2021-07-15 22:52:50
正如您所说,在nn.Module上使用前向钩子是最简单的方法。考虑一下文档:https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook
基本上,您只需定义一个函数,该函数接受三个输入(module, input, output),然后对这些数据做任何您想做的事情。要找到您想要放置钩子的模块,您显然需要熟悉模型的结构。您只需使用print(midas)即可获得所有可用模块的精美打印表示。我只是随机选择了一个,并使用print()函数作为钩子:
midas.pretrained.model.blocks[3].mlp.fc2.register_forward_hook(print)这意味着每当我们调用midas(some_input)时,钩子(在本例中为print)将使用相应的参数进行调用。当然,除了print之外,你还可以编写一个函数,将这些文件保存到一个可以从外部访问的列表中,或者将它们写到一个文件中等等。
https://stackoverflow.com/questions/68392911
复制相似问题