首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在火炬视觉模型中查找所有ReLU层

在火炬视觉模型中查找所有ReLU层
EN

Stack Overflow用户
提问于 2018-10-04 00:07:16
回答 1查看 204关注 0票数 1

在我从torchvision.models中获取一个经过预先训练的模型之后,我希望所有的ReLU实例都到register_backward_hook(f)中,如下所示:

代码语言:javascript
复制
for pos, module in self.model.features._modules.items():
    for sub_module in module:
        if isinstance(module, ReLU):
            module.register_backward_hook(f)

对我来说,问题是如何在模型中找到所有的ReLU。对于densenet161ReLU不仅存在于model.features._modules中,而且还存在于自定义的致密层中。model.features._modules['denseblock1'][0]。对于resnet151ReLU存在于model._modules及其自定义层(如model._modules['layer1'] )中。

有什么方法可以在模型中找到所有ReLU吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-10-04 04:46:36

对模型的所有组件进行迭代的一个更优雅的方法是使用modules()方法:

代码语言:javascript
复制
from torch import nn

for module in self.model.modules():
  if isinstance(module, nn.ReLU):
    module.register_backward_hook(f)

如果您不想获得所有子模块,只想获得直接子模块,可以考虑使用children()方法而不是modules()方法。您还可以使用named_modules()方法获取子模块的名称。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52637211

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档