首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在matplotlib中显示张量图像

在matplotlib中显示张量图像
EN

Stack Overflow用户
提问于 2019-03-10 09:13:41
回答 1查看 5.6K关注 0票数 4

我正在为Udacity的AI做一个使用Python nanodegree的项目。

我正在尝试显示从图像文件路径获得的torch.cuda.FloatTensor。下面的图像将是一个条形图,显示最有可能的5个花名及其相关的概率。

代码语言:javascript
复制
plt.figure(figsize=(3,3))
path = 'flowers/test/1/image_06743.jpg' 

top5_probs, top5_class_names = predict(path, model,5)

print(top5_probs)
print(top5_class_names)

flower_np_image = process_image(Image.open(path))
flower_tensor_image = torch.from_numpy(flower_np_image).type(torch.cuda.FloatTensor)
flower_tensor_image = flower_tensor_image.unsqueeze_(0)

axs = imshow(flower_tensor_image, ax = plt)
axs.axis('off')
axs.title(top5_class_names[0])
axs.show()
fig, ax = plt.subplots()
y_pos = np.arange(len(top5_class_names))
plt.barh(y_pos, list(reversed(top5_probs)))
plt.yticks(y_pos, list(reversed(top5_class_names)))
plt.ylabel('Flower Type')
plt.xlabel('Class Probability')

给我的imshow函数是这样的

代码语言:javascript
复制
def imshow(image, ax=None, title=None):
    if ax is None:
        fig, ax = plt.subplots()

    # PyTorch tensors assume the color channel is the first dimension
    # but matplotlib assumes is the third dimension
    image = image.transpose((1, 2, 0))

    # Undo preprocessing
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean

    # Image needs to be clipped between 0 and 1 or it looks like noise when displayed
    image = np.clip(image, 0, 1)

    ax.imshow(image)

    return ax

但是我得到了这个输出

代码语言:javascript
复制
[0.8310797810554504, 0.14590543508529663, 0.013837042264640331, 0.005048676859587431, 0.0027143193874508142]
['petunia', 'pink primrose', 'balloon flower', 'hibiscus', 'tree mallow']

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-17-f54be68feb7a> in <module>()
     12 flower_tensor_image = flower_tensor_image.unsqueeze_(0)
     13 
---> 14 axs = imshow(flower_tensor_image, ax = plt)
     15 axs.axis('off')
     16 axs.title(top5_class_names[0])

<ipython-input-15-9c543acc89cc> in imshow(image, ax, title)
      5     # PyTorch tensors assume the color channel is the first dimension
      6     # but matplotlib assumes is the third dimension
----> 7     image = image.transpose((1, 2, 0))
      8 
      9     # Undo preprocessing

TypeError: transpose(): argument 'dim0' (position 1) must be int, not tuple

<matplotlib.figure.Figure at 0x7f5855792160>

我的预测函数可以工作,但是imshow只会因为调用transpose而阻塞。有什么办法解决这个问题吗?我认为它隐约与转换回numpy数组有关。

我正在开发的笔记本可以在https://github.com/BozSteinkalt/ImageClassificationProject上找到

谢谢!

EN

回答 1

Stack Overflow用户

发布于 2019-03-10 14:34:08

您正尝试将numpy.transpose应用于torch.Tensor对象,因此改为调用tensor.transpose

您应该首先使用.numpy()flower_tensor_image转换为numpy

代码语言:javascript
复制
axs = imshow(flower_tensor_image.detach().cpu().numpy(), ax = plt)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55083571

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档