首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >不能替换Densenet121上的分类器

不能替换Densenet121上的分类器
EN

Stack Overflow用户
提问于 2019-09-05 13:51:22
回答 2查看 1.5K关注 0票数 0

我正在尝试使用这个github DenseNet121模型(https://github.com/gaetandi/cheXpert.git)进行一些转移学习。我遇到了一些问题,将分类层的大小从14个调整到2个。

github代码的相关部分是:

代码语言:javascript
复制
class DenseNet121(nn.Module):
    """Model modified.
    The architecture of our model is the same as standard DenseNet121
    except the classifier layer which has an additional sigmoid function.
    """
    def __init__(self, out_size):
        super(DenseNet121, self).__init__()
        self.densenet121 = torchvision.models.densenet121(pretrained=True)
        num_ftrs = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_ftrs, out_size),
            nn.Sigmoid()
        )
def forward(self, x):
    x = self.densenet121(x)
    return x

我装载并输入:

代码语言:javascript
复制
# initialize and load the model
model = DenseNet121(nnClassCount).cuda()
model = torch.nn.DataParallel(model).cuda()
modeldict = torch.load("model_ones_3epoch_densenet.tar")
model.load_state_dict(modeldict['state_dict'])

看起来,DenseNet不会将层划分成子层,因此model = nn.Sequential(*list(modelRes.children())[:-1])将无法工作。

model.classifier = nn.Linear(1024, 2)似乎在默认的DenseNets上工作,但是在这里使用修改后的分类器(附加的sigmoid函数),它只是添加了一个额外的分类器层,而没有替换原来的分类器层。

我试过了

代码语言:javascript
复制
model.classifier = nn.Sequential(
    nn.Linear(1024, dset_classes_number), 
    nn.Sigmoid()
)

但我有相同的添加而不是替换的分类器问题:

代码语言:javascript
复制
...
      )
      (classifier): Sequential(
        (0): Linear(in_features=1024, out_features=14, bias=True)
        (1): Sigmoid()
      )
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=2, bias=True)
    (1): Sigmoid()
  )
)
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-09-05 19:18:32

如果要替换作为classifier成员的densenet121内部的model,则需要分配

代码语言:javascript
复制
model.densenet121.classifier = nn.Sequential(...)
票数 0
EN

Stack Overflow用户

发布于 2022-08-03 16:24:55

如果我理解您的问题,下面的代码将解决

代码语言:javascript
复制
import torchvision.models as models
import torch
from torch import nn

import numpy as np

np.random.seed(0)
torch.manual_seed(0)




densenet121 = models.densenet121(pretrained=True)

for param in densenet121.parameters():
    param.requires_grad = False

densenet121.classifier = nn.Sequential(
    nn.Linear(1024, 14),
    nn.ReLU(),
    nn.Dropout(0.4),


    nn.Linear(14, 2),
)


densenet121.cuda()
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57807050

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档