首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >DenseNet,张量的大小必须匹配

DenseNet,张量的大小必须匹配
EN

Stack Overflow用户
提问于 2020-11-24 18:11:11
回答 1查看 186关注 0票数 1

你知道我如何调整这段代码,使张量的大小必须匹配,因为我有这个错误:x = torch.cat([x1,x2],1) RuntimeError: Sizes of tensors must match except in dimension 0. Got 32 and 1 (The offending index is 0)

我的图片尺寸是416x416。

提前感谢你的帮助,

代码语言:javascript
复制
num_classes = 20
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
                
        self.inc = models.inception_v3(pretrained=True)
        self.inc.aux_logits = False

        for child in list(self.inc.children())[:-5]:
            for param in child.parameters():
                param.requires_grad = False

        self.inc.fc = nn.Sequential()
                    
        self.dens121 = models.densenet121(pretrained=True)

        for child in list(self.dens121.children())[:-6]:
            for param in child.parameters():
                param.requires_grad = False

        self.dens121 = nn.Sequential(*list(self.dens121.children())[:-1])
           
        self.SiLU = nn.SiLU()      
        self.linear = nn.Linear(4096, num_classes)
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, x):
        x1 = self.SiLU(self.dens121(x))
        x1 = x1.view(-1, 2048)
        
        x2 = self.inc(x).view(-1, 2048)
        x = torch.cat([x1,x2],1)

        return self.linear(self.dropout(x))
EN

回答 1

Stack Overflow用户

发布于 2020-11-24 20:40:02

这两个张量的形状非常不同,这就是torch.cat()失败的原因。我试着用下面的例子运行你的代码:

代码语言:javascript
复制
def forward(self, x):
    x1 = self.SiLU(self.dens121(x))
    x1 = x1.view(-1, 2048)
        
    x2 = self.inc(x).view(-1, 2048)
    print(x1.shape, x2.shape)
    x = torch.cat([x1,x2], dim=1)

    return self.linear(self.dropout(x))

下面是驱动程序代码

代码语言:javascript
复制
inputs = torch.randn(2, 3, 416, 416)
model = Net()
outputs = model(inputs)

x2的x1形状如下:

代码语言:javascript
复制
torch.Size([169, 2048]) torch.Size([2, 2048])

您的DenseNet输出的形状应与Inceptionv3的输出相同,反之亦然。DenseNet的输出为torch.Size([2, 1024, 13, 13])形状,Inceptionv3的输出为torch.Size([2, 2048])形状。

EDIT将以下行添加到init方法:

代码语言:javascript
复制
self.conv_reshape= nn.Conv2d(1024, 2048, kernel_size=13, stride=1)

将这些行添加到您的forward()

代码语言:javascript
复制
x1 = self.SiLU(self.dens121(x))

out = self.conv_reshape(x1)
x1 = out.view(-1, out.size(1))

x2 = self.inc(x).view(-1, 2048)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64984301

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档