首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >一幅图像的pytorch vgg模型试验

一幅图像的pytorch vgg模型试验
EN

Stack Overflow用户
提问于 2018-12-10 17:01:26
回答 1查看 1.8K关注 0票数 0

我训练了一个vgg模型,这就是我转换测试数据的方式

代码语言:javascript
复制
test_transform_2= transforms.Compose([transforms.RandomResizedCrop(224), 
                                     transforms.ToTensor()])

test_data = datasets.ImageFolder(test_dir, transform=test_transform_2)

模型已经训练完毕,现在我想在一个图像上测试它

代码语言:javascript
复制
from scipy import misc

test_image = misc.imread('flower_data/valid/1/image_06739.jpg')
vgg16(torch.from_numpy(test_image))

错误

代码语言:javascript
复制
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-60-b83587325fea> in <module>
----> 1 vgg16(torch.from_numpy(test_image))

c:\users\sam\mydocu~1\code\envs\data-science\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

c:\users\sam\mydocu~1\code\envs\data-science\lib\site-packages\torchvision\models\vgg.py in forward(self, x)
     40 
     41     def forward(self, x):
---> 42         x = self.features(x)
     43         x = x.view(x.size(0), -1)
     44         x = self.classifier(x)

c:\users\sam\mydocu~1\code\envs\data-science\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

c:\users\sam\mydocu~1\code\envs\data-science\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
     89     def forward(self, input):
     90         for module in self._modules.values():
---> 91             input = module(input)
     92         return input
     93 

c:\users\sam\mydocu~1\code\envs\data-science\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

c:\users\sam\mydocu~1\code\envs\data-science\lib\site-packages\torch\nn\modules\conv.py in forward(self, input)
    299     def forward(self, input):
    300         return F.conv2d(input, self.weight, self.bias, self.stride,
--> 301                         self.padding, self.dilation, self.groups)
    302 
    303 

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 3, 3], but got input of size [628, 500, 3] instead

我可以看出,我需要塑造输入,但我不知道如何基于它似乎期望输入是一批通知的方式。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-12-10 17:34:55

您的图像是[h, w, 3],其中3表示rgb通道,py手电期望[b, 3, h, w],其中b是批处理大小。所以你可以通过调用reshaped = img.permute(2, 0, 1).unsqueeze(0)来重塑它。我认为在某个地方也有一个实用函数,但我现在找不到它。

所以在你的情况下

代码语言:javascript
复制
tensor = torch.from_numpy(test_image)
reshaped = tensor.permute(2, 0 1).unsqueeze(0)
your_result = vgg16(reshaped)
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53710313

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档