首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在pytorch中提取MiDaS神经网络的中间表示?

在pytorch中提取MiDaS神经网络的中间表示?
EN

Stack Overflow用户
提问于 2021-07-15 19:23:30
回答 1查看 44关注 0票数 1

Pytorch文档提供了一个应用MiDaS单目深度估计网络进行深度提取的concise way。但是,我应该如何修改它们的代码,以便在某些中间层提取网络表示呢?我知道我可以从github下载模型并修改forward函数来返回我想要的东西,但我对最简单的解决方案感兴趣,让外部代码保持原样。

我知道对模型类进行子类化并编写自己的转发函数,比如here,但我不知道如何在代码中访问该类。模型实例是用midas = torch.hub.load("intel-isl/MiDaS", model_type)直接创建的。也许使用前向钩子的例子会更简单。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-07-15 22:52:50

正如您所说,在nn.Module上使用前向钩子是最简单的方法。考虑一下文档:https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook

基本上,您只需定义一个函数,该函数接受三个输入(module, input, output),然后对这些数据做任何您想做的事情。要找到您想要放置钩子的模块,您显然需要熟悉模型的结构。您只需使用print(midas)即可获得所有可用模块的精美打印表示。我只是随机选择了一个,并使用print()函数作为钩子:

代码语言:javascript
复制
midas.pretrained.model.blocks[3].mlp.fc2.register_forward_hook(print)

这意味着每当我们调用midas(some_input)时,钩子(在本例中为print)将使用相应的参数进行调用。当然,除了print之外,你还可以编写一个函数,将这些文件保存到一个可以从外部访问的列表中,或者将它们写到一个文件中等等。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68392911

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档