首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何删除/替换现有模型中的层?

如何删除/替换现有模型中的层?
EN

Stack Overflow用户
提问于 2021-09-29 12:44:58
回答 1查看 1.6K关注 0票数 0
  1. 如何从预培训网络中删除某些层(例如删除单个ReLU激活层)?
  2. 如何按类型替换预培训网络中的某些层(例如用AvrPool替换MaxPool2d )?
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-09-29 13:34:34

假设您知道模型的结构,您可以:

代码语言:javascript
复制
>>> model = torchvision.models(pretrained=True)

  • 选择一个子模块并与它交互,就像与任何其他nn.Module一样。这将取决于您的模型的实现。例如,子模块通常可以通过属性(例如model.features)访问,但是情况并不总是如此,例如,nn.Sequential使用索引:model.features[18]来选择relu激活之一。还请注意:并非所有层都在nn.Module内部注册,非参数函数(如大多数激活函数)可以通过函数方法直接应用于模块的前面。

  • 对于给定的nn.Module m,可以使用type(m).__name__提取其层名。一种规范的方法是过滤model.modules的层,只保留最大的池层,然后用model.named_modules()中的平均池k= k,m替换那些平均池层。如果输入(M).__name__ == 'MaxPool2d‘

我们可以提取每个层的父模块名称:

最大池= k.split('.')对于k,m在model.named_modules()中.如果输入(M).__name__ ==‘MaxPool2d’功能‘,'4',’特征‘,'9',’特征‘,'16',’特征‘,'23',’特征‘,'30']

在这里,它们都来自同一个父模块model.features。最后,我们可以获取层引用以覆盖它们的值:

对于*父母,k在最大池:.Model.get_submodule(‘.联接(父))int(K)= nn.AvgPool2d(2,2)

其结果是:

VGG( (特征):顺序(0):Conv2d(3,64,kernel_size=(3,3),stride=(1,1),padding=(1,1)) (1):ReLU(inplace=True) (2):Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1)) (3):ReLU(inplace=True) (4):AvgPool2d(kernel_size=2,stride=2,padding=0) (5):Conv2d(64个)kernel_size=(3,3),stride=(1,1),padding=(1,1) (6):ReLU(inplace=True) (7):Conv2d(128,128,kernel_size=(3,3),stride=(1,1),padding=(1,1)) (8):ReLU(inplace=True) (9):AvgPool2d(kernel_size=2,stride=2,padding=0) (10):Conv2d(128,256,kernel_size=(3,3),Conv2d(1,1),padding=(1,1) (11):ReLU(inplace=True) (12):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1) (13):ReLU(inplace=True) (14):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1)) (15):ReLU(inplace=True) (16):AvgPool2d(kernel_size=2,inplace=True,)padding=0) (17):Conv2d(256,512,kernel_size=(3,3),stride=(1,1),padding=(1,1)) (18):ReLU(inplace=True) (19):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1) (20):ReLU(inplace=True) (21):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))1) (22):ReLU(inplace=True) (23):AvgPool2d(kernel_size=2,stride=2,padding=0) (24):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1) (25):ReLU(inplace=True) (26):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1) (27):ReLU(inplace=True) (28):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1)) (29):ReLU(inplace=True) (30):AvgPool2d(kernel_size=2,stride=2,padding=0) (avgpool):AdaptiveAvgPool2d(output_size=(7,7)) (分类器):顺序(0):线性(in_features=25088,out_features=4096,) ( bias=True) (1):ReLU(inplace=True) (2):辍学(p=0.5,inplace=False) (3):线性(in_features=4096,out_features=4096,bias=True) (4):ReLU(inplace=True) (5):辍学(p=0.5,inplace=False) (6):线性(in_features=4096,out_features=1000,bias=True) )

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69376651

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档