首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >试图在一张图像上预测CNN火把的错误吗?

试图在一张图像上预测CNN火把的错误吗?
EN

Stack Overflow用户
提问于 2021-02-24 10:23:23
回答 2查看 151关注 0票数 0

错误消息

追溯(最近一次调用):文件"pred.py",第134行,在输出=模型(数据)运行时错误:4维权重16、3、3、3的预期4维输入,但得到了尺寸为1、32、32的三维输入。

预测码

代码语言:javascript
复制
normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467],
                                     std=[0.2471, 0.2435, 0.2616])
train_set = transforms.Compose([
                                 transforms.RandomCrop(32, padding=4),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 normalize,
                                     ])

model = models.condensenet(args)
model = nn.DataParallel(model)
PATH = "results/savedir/save_models/checkpoint_001.pth.tar"

model.load_state_dict(torch.load(PATH)['state_dict'])


device = torch.device("cpu")

model.eval()

image = Image.open("horse.jpg")
input = train_set(image)
train_loader = torch.utils.data.DataLoader(
        input,
        batch_size=1,shuffle=True, num_workers=1)
for i, data in enumerate(train_loader):
    
    #input_var = torch.autograd.Variable(data, volatile=True)
    #input_var = input_var.view(1, 3, 32,32)
    
    **output = model(data)
topk=(1,5)
maxk = max(topk)

_, pred = output.topk(maxk, 1, True, True)

当我试图对单个图像进行预测时,会出现这个错误吗? 图像形状/大小错误消息

链接到保存的模型

训练代码库

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-03-11 00:10:48

而不是执行for循环和train_loader,而是通过直接将输入传递到模型来解决这个问题。像这样

代码语言:javascript
复制
input = train_set(image)
input = input.unsqueeze(0)
model.eval()
output = model(input)

更详细的信息可以在这里找到链接

票数 0
EN

Stack Overflow用户

发布于 2021-02-24 16:13:35

请取消注释这一行#input_var = input_var.view(1, 3, 32,32),以便输入维数为4。

我猜你不会吧。如果输入通道为3,则使用input_var = input_var.view(1, 1, 32,32) (如果灰度)

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66348912

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档