
RuntimeError: Expected 4-dimensional input for 4-dimensional weight

Stack Overflow user
Asked on 2021-06-16 18:37:56
1 answer · 190 views · 0 followers · Score: 0

I have a network in which three architectures share the same classifier.

import torch
import torch.nn as nn
import torch.nn.functional as F


class VGGBlock(nn.Module):
    def __init__(self, in_channels, out_channels,batch_norm=False):

        super(VGGBlock,self).__init__()

        conv2_params = {'kernel_size': (3, 3),
                        'stride'     : (1, 1),
                        'padding'   : 1
                        }

        noop = lambda x : x

        self._batch_norm = batch_norm

        self.conv1 = nn.Conv2d(in_channels=in_channels,out_channels=out_channels , **conv2_params)
        self.bn1 = nn.BatchNorm2d(out_channels) if batch_norm else noop

        self.conv2 = nn.Conv2d(in_channels=out_channels,out_channels=out_channels, **conv2_params)
        self.bn2 = nn.BatchNorm2d(out_channels) if batch_norm else noop

        self.max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    @property
    def batch_norm(self):
        return self._batch_norm

    def forward(self,x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)

        x = self.max_pooling(x)

        return x
class VGG16(nn.Module):

  def __init__(self, input_size, num_classes=1,batch_norm=False):
    super(VGG16, self).__init__()

    self.in_channels,self.in_width,self.in_height = input_size

    self.block_1 = VGGBlock(self.in_channels,64,batch_norm=batch_norm)
    self.block_2 = VGGBlock(64, 128,batch_norm=batch_norm)
    self.block_3 = VGGBlock(128, 256,batch_norm=batch_norm)
    self.block_4 = VGGBlock(256,512,batch_norm=batch_norm)

  @property
  def input_size(self):
      return self.in_channels,self.in_width,self.in_height

  def forward(self, x):

    x = self.block_1(x)
    x = self.block_2(x)
    x = self.block_3(x)
    x = self.block_4(x)

    return x
class VGG16Classifier(nn.Module):

  def __init__(self, num_classes=1,classifier = None,batch_norm=False):
    super(VGG16Classifier, self).__init__()


    self._vgg_a = VGG16((1,32,32),batch_norm=True)
    self._vgg_b = VGG16((1,32,32),batch_norm=True)
    self._vgg_star = VGG16((1,32,32),batch_norm=True)
    self.classifier = classifier

    if (self.classifier is None):
        self.classifier = nn.Sequential(
          nn.Linear(2048, 2048),
          nn.ReLU(True),
          nn.Dropout(p=0.5),
          nn.Linear(2048, 512),
          nn.ReLU(True),
          nn.Dropout(p=0.5),
          nn.Linear(512, num_classes)
        )

  def forward(self, x1,x2,x3):
      op1 = self._vgg_a(x1)
      op1 = torch.flatten(op1,1)
      op2 = self._vgg_b(x2)
      op2 = torch.flatten(op2,1)
      op3 = self._vgg_star(x3) 
      op3 = torch.flatten(op3,1)
      
      x1 = self.classifier(op1)
      x2 = self.classifier(op2)
      x3 = self.classifier(op3)

      return x1,x2,x3
model1 = VGG16((1,32,32),batch_norm=True)
model2 = VGG16((1,32,32),batch_norm=True)
model_star = VGG16((1,32,32),batch_norm=True)
model_combo = VGG16Classifier(model1,model2,model_star)
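As a sanity check (my own sketch, not from the original post), one can verify that a single branch of this shape maps a (N, 1, 32, 32) input to 2048 flattened features, matching the classifier's first `nn.Linear(2048, 2048)` layer. The `make_branch` helper below is a hypothetical stand-in for one VGG16 branch, reproducing the same conv/pool structure:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for one VGG16 branch from the post:
# four blocks of two 3x3 convs (padding 1) followed by a 2x2 max-pool.
def make_branch(in_ch=1):
    layers, ch = [], in_ch
    for out_ch in (64, 128, 256, 512):
        layers += [nn.Conv2d(ch, out_ch, 3, padding=1), nn.ReLU(),
                   nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(),
                   nn.MaxPool2d(2, 2)]
        ch = out_ch
    return nn.Sequential(*layers)

branch = make_branch()
x = torch.randn(4, 1, 32, 32)     # N, C, H, W
feats = branch(x)                 # four pools: 32 -> 16 -> 8 -> 4 -> 2
flat = torch.flatten(feats, 1)
print(feats.shape)                # torch.Size([4, 512, 2, 2])
print(flat.shape)                 # torch.Size([4, 2048])
```

So 512 channels times 2x2 spatial output gives exactly the 2048 features the classifier consumes.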

I want to train model_combo with the following loss function:

class CombinedLoss(nn.Module):
    def __init__(self, loss_a, loss_b, loss_star, _lambda=1.0):
        super().__init__()
        self.loss_a = loss_a
        self.loss_b = loss_b
        self.loss_star = loss_star

        self.register_buffer('_lambda',torch.tensor(float(_lambda),dtype=torch.float32))


    def forward(self,y_hat,y):

        return (self.loss_a(y_hat[0],y[0]) + 
                self.loss_b(y_hat[1],y[1]) + 
                self.loss_star(y_hat[2],y[2]) + 
                self._lambda * torch.sum(model_star.weight - torch.pow(torch.cdist(model1.weight+model2.weight), 2)))
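The λ term above cannot run as written (`VGG16` has no `.weight` attribute, and `torch.cdist` takes two tensors). A shape-consistent version of what may have been intended (this is my guess at the intent, not the asker's code) would iterate over matching parameters of the three models:

```python
import torch
import torch.nn as nn

# Hypothetical sketch of the lambda penalty: for each pair of matching
# parameters, penalize the squared distance between the "star" model's
# weights and the average of models a and b. Toy Linear modules stand
# in for the three VGG16 branches.
def weight_penalty(model_star, model_a, model_b):
    total = torch.tensor(0.0)
    for p_s, p_a, p_b in zip(model_star.parameters(),
                             model_a.parameters(),
                             model_b.parameters()):
        total = total + torch.sum((p_s - 0.5 * (p_a + p_b)) ** 2)
    return total

m1, m2, ms = (nn.Linear(4, 3) for _ in range(3))
print(weight_penalty(ms, m1, m2))  # a non-negative scalar tensor
```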

In the training function I pass the loaders, which for simplicity are loaders_a, loaders_b and loaders_a, where loaders_a holds the first 50% of MNIST and loaders_b the second 50%.
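The 50/50 split described above could be sketched with `torch.utils.data.Subset` (an illustration of the split only; the random dataset below is a stand-in for MNIST, and the batch size is an assumption):

```python
import torch
from torch.utils.data import Subset, DataLoader, TensorDataset

# Hypothetical stand-in dataset; in the post this would be MNIST.
data = TensorDataset(torch.randn(100, 1, 32, 32),
                     torch.randint(0, 2, (100,)))

half = len(data) // 2
first_half = Subset(data, range(0, half))            # -> loaders_a
second_half = Subset(data, range(half, len(data)))   # -> loaders_b

loader_a = DataLoader(first_half, batch_size=32, shuffle=True)
loader_b = DataLoader(second_half, batch_size=32, shuffle=True)
print(len(first_half), len(second_half))  # 50 50
```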

def train(net, loaders, optimizer, criterion, epochs=20, dev=None, save_param=False, model_name="valerio"):
      loaders_a, loaders_b, loaders_star = loaders
    # try:
      net = net.to(dev)
      #print(net)
      #summary(net,[(net.in_channels,net.in_width,net.in_height)]*2)


      criterion.to(dev)


      # Initialize history
      history_loss = {"train": [], "val": [], "test": []}
      history_accuracy_a = {"train": [], "val": [], "test": []}
      history_accuracy_b = {"train": [], "val": [], "test": []}
      history_accuracy_star = {"train": [], "val": [], "test": []}
      # Store the best val accuracy
      best_val_accuracy = 0

      # Process each epoch
      for epoch in range(epochs):
        # Initialize epoch variables
        sum_loss = {"train": 0, "val": 0, "test": 0}
        sum_accuracy_a = {"train": 0, "val": 0, "test": 0}
        sum_accuracy_b = {"train": 0, "val": 0, "test": 0}
        sum_accuracy_star = {"train": 0, "val": 0, "test": 0}

        progbar = None
        # Process each split
        for split in ["train", "val", "test"]:
          if split == "train":
            net.train()
            #widgets = [
              #' [', pb.Timer(), '] ',
              #pb.Bar(),
              #' [', pb.ETA(), '] ', pb.Variable('ta','[Train Acc: {formatted_value}]')]

            #progbar = pb.ProgressBar(max_value=len(loaders_a[split]),widgets=widgets,redirect_stdout=True)

          else:
            net.eval()
          # Process each batch
          for j, ((input_a, labels_a), (input_b, labels_b), (input_s, labels_s)) in enumerate(zip(loaders_a[split], loaders_b[split], loaders_star[split])):
            labels_a = labels_a.unsqueeze(1).float()
            labels_b = labels_b.unsqueeze(1).float()
            labels_s = labels_s.unsqueeze(1).float()

            input_a = input_a.to(dev)
            labels_a = labels_a.to(dev)
            input_b = input_b.to(dev)
            labels_b = labels_b.to(dev)
            input_s = input_s.to(dev)
            labels_s = labels_s.to(dev)

            # Reset gradients
            optimizer.zero_grad()
            # Compute output
            pred = net(input_a,input_b, input_s)

            loss = criterion(pred, [labels_a, labels_b, labels_s])
            # Update loss
            sum_loss[split] += loss.item()
            # Check parameter update
            if split == "train":
              # Compute gradients
              loss.backward()
              # Optimize
              optimizer.step()

            # Compute accuracy
            pred_labels = (pred[2] >= 0.0).long()  # Binarize predictions to 0 and 1
            pred_labels_a = (pred[0] >= 0.0).long()  # Binarize predictions to 0 and 1
            pred_labels_b = (pred[1] >= 0.0).long()  # Binarize predictions to 0 and 1


            batch_accuracy_star = (pred_labels == labels_s).sum().item() / len(labels_s)
            batch_accuracy_a = (pred_labels_a == labels_a).sum().item() / len(labels_a)
            batch_accuracy_b = (pred_labels_b == labels_b).sum().item() / len(labels_b)
            # Update accuracy
            sum_accuracy_star[split] += batch_accuracy_star
            sum_accuracy_a[split] += batch_accuracy_a
            sum_accuracy_b[split] += batch_accuracy_b

            #if (split=='train'):
              #progbar.update(j, ta=batch_accuracy)
              #progbar.update(j, ta=batch_accuracy_a)
              #progbar.update(j, ta=batch_accuracy_b)

        #if (progbar is not None):
          #progbar.finish()
        # Compute epoch loss/accuracy
        #for split in ["train", "val", "test"]:
          #epoch_loss = sum_loss[split] / (len(loaders_a[split])+len(loaders_b[split])) 
          #epoch_accuracy_combo = {split: sum_accuracy_combo[split] / len(loaders[split]) for split in ["train", "val", "test"]}
          #epoch_accuracy_a = sum_accuracy_a[split] / len(loaders_a[split])
          #epoch_accuracy_b = sum_accuracy_b[split] / len(loaders_b[split])
        epoch_loss = sum_loss["train"] / (len(loaders_a["train"])+len(loaders_b["train"])+len(loaders_star["train"]))
        epoch_accuracy_a = sum_accuracy_a["train"] / len(loaders_a["train"])
        epoch_accuracy_b = sum_accuracy_b["train"] / len(loaders_b["train"])
        epoch_accuracy_star = sum_accuracy_star["train"] / len(loaders_star["train"])

        epoch_loss_val = sum_loss["val"] / (len(loaders_a["val"])+len(loaders_b["val"])+len(loaders_star["val"]))
        epoch_accuracy_a_val = sum_accuracy_a["val"] / len(loaders_a["val"])
        epoch_accuracy_b_val = sum_accuracy_b["val"] / len(loaders_b["val"])
        epoch_accuracy_star_val = sum_accuracy_star["val"] / len(loaders_star["val"])

        epoch_loss_test = sum_loss["test"] / (len(loaders_a["test"])+len(loaders_b["test"])+len(loaders_star["test"]))
        epoch_accuracy_a_test = sum_accuracy_a["test"] / len(loaders_a["test"])
        epoch_accuracy_b_test = sum_accuracy_b["test"] / len(loaders_b["test"])
        epoch_accuracy_star_test = sum_accuracy_star["test"] / len(loaders_star["test"])


        # Store params at the best validation accuracy
        if save_param and epoch_accuracy_star_val > best_val_accuracy:
          # torch.save(net.state_dict(), f"{net.__class__.__name__}_best_val.pth")
          torch.save(net.state_dict(), f"{model_name}_best_val.pth")
          best_val_accuracy = epoch_accuracy_star_val

        # Update history
        for split, metrics in zip(
            ["train", "val", "test"],
            [(epoch_loss, epoch_accuracy_a, epoch_accuracy_b, epoch_accuracy_star),
             (epoch_loss_val, epoch_accuracy_a_val, epoch_accuracy_b_val, epoch_accuracy_star_val),
             (epoch_loss_test, epoch_accuracy_a_test, epoch_accuracy_b_test, epoch_accuracy_star_test)]):
          split_loss, acc_a, acc_b, acc_star = metrics
          history_loss[split].append(split_loss)
          history_accuracy_a[split].append(acc_a)
          history_accuracy_b[split].append(acc_b)
          history_accuracy_star[split].append(acc_star)
        # Print info
        print(f"Epoch {epoch + 1}:",
              f"Training Loss = {epoch_loss:.4f},",)
        print(f"Epoch {epoch + 1}:",
              f"Training Accuracy for A = {epoch_accuracy_a:.4f},")
        print(f"Epoch {epoch + 1}:",
              f"Training Accuracy for B = {epoch_accuracy_b:.4f},")
        print(f"Epoch {epoch + 1}:",
              f"Training Accuracy for star = {epoch_accuracy_star:.4f},")
        
        print(f"Epoch {epoch + 1}:",
              f"Val Loss = {epoch_loss_val:.4f},",)
        print(f"Epoch {epoch + 1}:",
              f"Val Accuracy for A = {epoch_accuracy_a_val:.4f},")
        print(f"Epoch {epoch + 1}:",
              f"Val Accuracy for B = {epoch_accuracy_b_val:.4f},")
        print(f"Epoch {epoch + 1}:",
              f"Val Accuracy for star = {epoch_accuracy_star_val:.4f},")
        
        print(f"Epoch {epoch + 1}:",
              f"Test Loss = {epoch_loss_test:.4f},",)
        print(f"Epoch {epoch + 1}:",
              f"Test Accuracy for A = {epoch_accuracy_a_test:.4f},")
        print(f"Epoch {epoch + 1}:",
              f"Test Accuracy for B = {epoch_accuracy_b_test:.4f},")
        print(f"Epoch {epoch + 1}:",
              f"Test Accuracy for star = {epoch_accuracy_star_test:.4f},")
        print("\n")

But I get this error:

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 1, 3, 3], but got 2-dimensional input of size [128, 2048] instead
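The error can be reproduced in isolation (my sketch, not from the post): a Conv2d layer's weight is 4-dimensional, so feeding it a 2-d tensor of shape [128, 2048] raises this RuntimeError:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3)  # weight: [64, 1, 3, 3]
flat = torch.randn(128, 2048)  # 2-d tensor, like the flattened features in the error

try:
    conv(flat)
except RuntimeError as e:
    print(e)  # raises RuntimeError (exact wording varies by PyTorch version)
```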
1 Answer

Stack Overflow user

Answered on 2021-06-16 19:40:40

From your code and the error, I'd guess you are passing binary images (h, w, 1) to the network. The problem occurs at a Conv2d layer, which requires 4-dimensional input. In other words, a Conv2d layer expects a 4-dimensional tensor like this:

T = torch.randn(1,3,128,256)
print(T.shape)
out: torch.Size([1, 3, 128, 256])

where:

  • The first dimension (1) is the batch dimension, used to stack multiple tensors along this dimension so that operations can be batched.

  • The second dimension (3) is in_channels: the number of channels of the image. Standard RGB or BGR images have 3 channels.

  • The third dimension (128) is the height dimension.

  • The fourth dimension (256) is the width dimension.

A binary image has a single channel dimension:

[128, 256, 1] , [Height, Width, Channels]
OR 
[128, 256], [Height, Width]
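To go from either of these layouts to the B, C, H, W layout torch expects, one can use `permute`/`unsqueeze` (my own sketch of the conversion, not part of the original answer):

```python
import torch

# H, W, C image -> move channels first, then add a batch dimension
img_hwc = torch.randn(128, 256, 1)                 # [Height, Width, Channels]
img_bchw = img_hwc.permute(2, 0, 1).unsqueeze(0)   # -> [1, 1, 128, 256]

# H, W image (channel already squeezed away) -> add channel and batch dims
img_hw = torch.randn(128, 256)                     # [Height, Width]
img_hw_bchw = img_hw.unsqueeze(0).unsqueeze(0)     # -> [1, 1, 128, 256]

print(img_bchw.shape, img_hw_bchw.shape)
```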

Consider this:

Standard image arrays have shape H, W, C, whereas torch expects the channel dimension right after the batch dimension, i.e. B, C, H, W.

I'm not sure where the channel squeezing happens in your pipeline, but the binary image ends up 2-dimensional, since a channel dimension of size one can simply be dropped.

If you want to pass a 2-dimensional binary image to a Conv2d layer, you should unsqueeze it into a 4-dimensional tensor first.

Before: Input.shape = torch.Size([128, 2048])

Preprocessing:

Tensor = Input.view(1, 1, Input.shape[0], Input.shape[1])
Tensor.shape
out: torch.Size([1, 1, 128, 2048])

The same can be achieved by unsqueezing the zeroth dim twice:

Tensor = Input.unsqueeze(0).unsqueeze(0)
Tensor.shape
out: torch.Size([1, 1, 128, 2048])

but it is messier, so I recommend the first option.

Score: 0
The original page content was provided by Stack Overflow; translation supported by Tencent Cloud's engine.
Original link: https://stackoverflow.com/questions/68001067
