在我从torchvision.models中获取一个经过预先训练的模型之后,我希望所有的ReLU实例都到register_backward_hook(f)中,如下所示:
for pos, module in self.model.features._modules.items():
for sub_module in module:
if isinstance(module, ReLU):
module.register_backward_hook(f)对我来说,问题是如何在模型中找到所有的ReLU。对于densenet161,ReLU不仅存在于model.features._modules中,而且还存在于自定义的致密层中。model.features._modules['denseblock1'][0]。对于resnet151,ReLU存在于model._modules及其自定义层(如model._modules['layer1'] )中。
有什么方法可以在模型中找到所有ReLU吗?
发布于 2018-10-04 04:46:36
对模型的所有组件进行迭代的一个更优雅的方法是使用modules()方法:
from torch import nn
for module in self.model.modules():
if isinstance(module, nn.ReLU):
module.register_backward_hook(f)如果您不想获得所有子模块,只想获得直接子模块,可以考虑使用children()方法而不是modules()方法。您还可以使用named_modules()方法获取子模块的名称。
https://stackoverflow.com/questions/52637211
复制相似问题