Predicted mask image has the wrong dimensions with UNet: TypeError: Invalid shape (2023, 2023, 256) for image data

Stack Overflow user
Asked 2022-09-21 23:49:03
Answers: 1 | Views: 79 | Followers: 0 | Votes: 0

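I have successfully trained a UNet and am now trying to validate its predictions. This question is related to one I asked here. The mask should have the same dimensions as the image and should be single-channel, right?

Please find the code below.

The saved model is loaded as follows: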

#load weights to network
weights_path = unet_dir + "unet1.pt"
device = "cpu"

unet = UNet(in_channels=3, out_channels=3, init_features=8)
unet.to(device)
unet.load_state_dict(torch.load(weights_path, map_location=device))

Helper functions:

#define augmentations 
inference_transform = A.Compose([
    A.Resize(256, 256, always_apply=True),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 
    ToTensorV2()
])

#define function for predictions
def predict(model, img, device):
    model.eval()
    with torch.no_grad():
        images = img.to(device)
        output = model(images)
        predicted_masks = (output.squeeze() >= 0.5).float().cpu().numpy()
        
    return(predicted_masks)

#define function to load image and output mask
def get_mask(img_path):
    image = cv2.imread(img_path)
    #assert image is not None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_height, original_width = tuple(image.shape[:2])
    
    image_trans = inference_transform(image = image)
    image_trans = image_trans["image"]
    image_trans = image_trans.unsqueeze(0)
    
    image_mask = predict(unet, image_trans, device)
    #image_mask = image_mask.astype(np.int16)
    image_mask = cv2.resize(image_mask,(original_width, original_height),
                          interpolation=cv2.INTER_NEAREST)
    #image_mask = cv2.resize(image_mask, (original_height, original_width))
    #Y_train[n] = mask > 0.5    
    return(image_mask)

#image example
example_path = "../input/test-image/10078.tiff"
image = cv2.imread(example_path)
#assert image is not None
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

mask = get_mask(example_path)

#masked_img = image*np.expand_dims(mask, 2).astype("uint8")

#plot the image, mask and multiplied together
fig, (ax1, ax2) = plt.subplots(2)

ax1.imshow(image)
ax2.imshow(mask)
#ax3.imshow(masked_img)

Output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_4859/3003834023.py in <module>
     13 
     14 ax1.imshow(image)
---> 15 ax2.imshow(mask)
     16 #ax3.imshow(masked_img)

/opt/conda/lib/python3.7/site-packages/matplotlib/_api/deprecation.py in wrapper(*args, **kwargs)
    457                 "parameter will become keyword-only %(removal)s.",
    458                 name=name, obj_type=f"parameter of {func.__name__}()")
--> 459         return func(*args, **kwargs)
    460 
    461     # Don't modify *func*'s signature, as boilerplate.py needs it.

/opt/conda/lib/python3.7/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1412     def inner(ax, *args, data=None, **kwargs):
   1413         if data is None:
-> 1414             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1415 
   1416         bound = new_sig.bind(ax, *args, **kwargs)

/opt/conda/lib/python3.7/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)
   5485                               **kwargs)
   5486 
-> 5487         im.set_data(X)
   5488         im.set_alpha(alpha)
   5489         if im.get_clip_path() is None:

/opt/conda/lib/python3.7/site-packages/matplotlib/image.py in set_data(self, A)
    714                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
    715             raise TypeError("Invalid shape {} for image data"
--> 716                             .format(self._A.shape))
    717 
    718         if self._A.ndim == 3:

TypeError: Invalid shape (2023, 2023, 256) for image data

Output image:

Can anyone help me with this?

Thanks and best regards,

Schröder Michael


1 Answer

Stack Overflow user

Answered 2022-09-22 07:36:37

First, if you use VS Code, I recommend this extension for inspecting images while debugging:

https://marketplace.visualstudio.com/items?itemName=elazarcoh.simply-view-image-for-python-debugging

Off the top of my head, I would say you should sum the values along one axis (I imagine they are one-hot encoded):

np.sum(A,axis = 1)
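
To see what this sum does to the array shape, here is a minimal NumPy sketch. It assumes the squeezed prediction is channel-first with shape (3, 256, 256), which is what out_channels=3 and Resize(256, 256) in the question suggest; the random data and the names pred, summed_rows, summed_channels and labels are hypothetical. On that layout, axis=1 collapses the image rows, while axis=0 is the channel axis. Because independently thresholded channels are not guaranteed to be one-hot, taking the argmax over the channel axis of the raw output may be a more robust way to get one class index per pixel.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for output.squeeze(): (channels, height, width)
pred = rng.random((3, 256, 256)).astype(np.float32)

# axis=1 is the height axis on this layout, so the result is not an image
summed_rows = np.sum(pred, axis=1)

# axis=0 is the channel axis; summing it yields one 2-D map
summed_channels = np.sum(pred, axis=0)

# argmax over the channel axis gives one class index (0, 1 or 2) per pixel
labels = np.argmax(pred, axis=0)

print(summed_rows.shape)      # (3, 256)
print(summed_channels.shape)  # (256, 256)
print(labels.shape)           # (256, 256)
```

Which axis is correct depends on the array A that the sum is applied to, so it is worth printing A.shape first.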

See this related question:

sum numpy ndarray with 3d array along a given axis 1
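
The trailing 256 in the error message is consistent with the channel-first array being resized as if it were height x width x channels: a (3, 256, 256) array is then read as 3 rows, 256 columns and 256 channels. The sketch below illustrates this under stated assumptions. resize_nearest is a hypothetical NumPy stand-in for cv2.resize with INTER_NEAREST, not the OpenCV implementation, and the target size is kept at 20 x 20 instead of 2023 x 2023 to save memory.

```python
import numpy as np

def resize_nearest(arr, width, height):
    """Nearest-neighbour resize over the first two axes.

    Hypothetical stand-in for cv2.resize(arr, (width, height),
    interpolation=cv2.INTER_NEAREST); any third axis is kept as channels.
    """
    rows = np.arange(height) * arr.shape[0] // height
    cols = np.arange(width) * arr.shape[1] // width
    return arr[rows][:, cols]

# Assumed channel-first prediction: (channels, height, width)
pred = np.zeros((3, 256, 256), dtype=np.float32)
pred[1, 64:192, 64:192] = 1.0  # class 1 fills the centre square

# Resizing the channel-first array treats axis 0 as height and axis 2 as channels
bad = resize_nearest(pred, 20, 20)
print(bad.shape)  # (20, 20, 256)

# Collapse the channel axis first, then resize the 2-D label map
mask = np.argmax(pred, axis=0).astype(np.uint8)
good = resize_nearest(mask, 20, 20)
print(good.shape)  # (20, 20)
```

In the question's get_mask, the analogous change would be to collapse the channel axis of the prediction before the cv2.resize call. Casting the label map to uint8 first is a cautious choice, since cv2.resize does not accept every integer dtype. A 2-D result can then be passed to imshow without the shape error.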

Votes: -1
Original page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/73807983
