首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何改变out_features of densenet121模型?

如何改变out_features of densenet121模型?
EN

Stack Overflow用户
提问于 2022-11-30 20:16:42
回答 1查看 21关注 0票数 0

如何改变out_features of densenet121模型?

我正在使用下面的代码来训练模型:

代码语言:javascript
复制
from torch.nn.modules.dropout import Dropout
    
class Densnet121(nn.Module):
    def __init__(self):
        super(Densnet121, self).__init__() 
        self.cnn1 = nn.Conv2d(in_channels=3 , out_channels=64 , kernel_size=3 , stride=1 )
        self.Densenet_121 = models.densenet121(pretrained=True)
        self.gap = AvgPool2d(kernel_size=2, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(1024)
        self.do1 = nn.Dropout(0.25)
        self.linear = nn.Linear(256,256)
        self.bn2 = nn.BatchNorm2d(256)
        self.do2 = nn.Dropout(0.25)
        self.output = nn.Linear(64 * 64 * 64,2)
        self.act = nn.ReLU()
        
    def densenet(self):
        for param in self.Densenet_121.parameters():
            param.requires_grad = False
        self.Densenet_121.classifier = nn.Linear(1024, 1024)
        return self.Densenet_121
    
    def forward(self, x):
        img = self.act(self.cnn1(x))
        img = self.densenet(img)      
    
        img = self.gap(img)
        img = self.bn1(img)
        img = self.do1(img)
        img = self.linear(img)
        img = self.bn2(img)
        img = self.do2(img)
        img = torch.flatten(img, 1)
        img = self.output(img)
    
        return img

在培训此模型时,我面临以下错误:

RuntimeError:给定groups=1,大小为64,3,7,7的重量,预期input64,64,62,62将有3个通道,而不是64个通道

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-12-01 08:51:38

您的第一个conv层输出形状(b, 64, h, w)的张量,而下面的层,密度集模型需要3个通道。由此引发的错误是:

“预期输入.有3通道,但获得64通道而不是

不幸的是,这个值是在Densenet类的源代码中硬编码的,参见参考文献

然而,一种解决方法是在密度网初始化后覆盖第一个卷积层。像这样的事情应该有效:

代码语言:javascript
复制
# First gather the conv layer specs
conv = self.Densenet_121.features.conv0
kwargs = {k: getattr(conv, k) for k in 
   ('out_channels', 'stride', 'kernel_size', 'padding', 'bias')}

# overwrite with identical specs with new in_channels
model.features.conv0 = nn.Conv2d(in_channels=64, **kwargs)    

或者,你可以这样做:

代码语言:javascript
复制
w = model.features.conv0.weight
w.data = torch.rand(len(w), 64, *w.shape[:2])

它在不影响其元数据的情况下取代底层卷积层的权重(例如。conv.in_channels仍与3相同),这可能会产生副作用。因此,我建议遵循第一种方法。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74633673

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档