我正在使用下面的代码来加载模型。
model.to(device)
checkpoint = torch.load("weights/vgg.pth")
if 'state_dict' in checkpoint:
checkpoint = checkpoint['state_dict']
ckpt = {k.replace('module.', ''):v for k,v in checkpoint.items()}
model.load_state_dict(ckpt)我发现了一个错误:
self.__class__.__name__,“n\t”.join(Error_msgs)) RuntimeError:为RepVGG加载state_dict中的错误:linear.weight的大小不匹配:从检查点复制带有形状torch.Size(1000,1280)的param,当前模型中的形状为torch.Size(8,1280)。Linear.bias的大小不匹配:从检查点复制带有形状torch.Size(1000)的参数,当前模型中的形状是torch.Size(8).
发布于 2022-09-07 08:31:13
现在的模型似乎被配置为提供8类的分类(num_class=8)。但是,您正在加载的检查点是一个VGG模型,它是在ImageNet上预先训练的,它有1000个类。因此,在最后一层中,权重和偏置的尺寸不匹配。
https://stackoverflow.com/questions/73631911
复制相似问题