我已经成功地训练了一个unet网络,目前我正试图验证这些预测。这个问题与我问here的这个问题有关。这个口罩应该有相同的尺寸,而且应该是单通道的,对吗?
请查找以下代码:
保存的模型如下:
#load weights to network
weights_path = unet_dir + "unet1.pt"
device = "cpu"
unet = UNet(in_channels=3, out_channels=3, init_features=8)
unet.to(device)
unet.load_state_dict(torch.load(weights_path, map_location=device))初步职能:
#define augmentations
inference_transform = A.Compose([
A.Resize(256, 256, always_apply=True),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
ToTensorV2()
])
#define function for predictions
def predict(model, img, device):
model.eval()
with torch.no_grad():
images = img.to(device)
output = model(images)
predicted_masks = (output.squeeze() >= 0.5).float().cpu().numpy()
return(predicted_masks)
#define function to load image and output mask
def get_mask(img_path):
image = cv2.imread(img_path)
#assert image is not None
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
original_height, original_width = tuple(image.shape[:2])
image_trans = inference_transform(image = image)
image_trans = image_trans["image"]
image_trans = image_trans.unsqueeze(0)
image_mask = predict(unet, image_trans, device)
#image_mask = image_mask.astype(np.int16)
image_mask = cv2.resize(image_mask,(original_width, original_height),
interpolation=cv2.INTER_NEAREST)
#image_mask = cv2.resize(image_mask, (original_height, original_width))
#Y_train[n] = mask > 0.5
return(image_mask)#image example
example_path = "../input/test-image/10078.tiff"
image = cv2.imread(example_path)
#assert image is not None
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask = get_mask(example_path)
#masked_img = image*np.expand_dims(mask, 2).astype("uint8")
#plot the image, mask and multiplied together
fig, (ax1, ax2) = plt.subplots(2)
ax1.imshow(image)
ax2.imshow(mask)
#ax3.imshow(masked_img)输出:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_4859/3003834023.py in <module>
13
14 ax1.imshow(image)
---> 15 ax2.imshow(mask)
16 #ax3.imshow(masked_img)
/opt/conda/lib/python3.7/site-packages/matplotlib/_api/deprecation.py in wrapper(*args, **kwargs)
457 "parameter will become keyword-only %(removal)s.",
458 name=name, obj_type=f"parameter of {func.__name__}()")
--> 459 return func(*args, **kwargs)
460
461 # Don't modify *func*'s signature, as boilerplate.py needs it.
/opt/conda/lib/python3.7/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
1412 def inner(ax, *args, data=None, **kwargs):
1413 if data is None:
-> 1414 return func(ax, *map(sanitize_sequence, args), **kwargs)
1415
1416 bound = new_sig.bind(ax, *args, **kwargs)
/opt/conda/lib/python3.7/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)
5485 **kwargs)
5486
-> 5487 im.set_data(X)
5488 im.set_alpha(alpha)
5489 if im.get_clip_path() is None:
/opt/conda/lib/python3.7/site-packages/matplotlib/image.py in set_data(self, A)
714 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
715 raise TypeError("Invalid shape {} for image data"
--> 716 .format(self._A.shape))
717
718 if self._A.ndim == 3:
TypeError: Invalid shape (2023, 2023, 256) for image data输出图像:

有人能在这件事上帮我吗?
谢谢&致以最良好的问候
施罗德·迈克尔
发布于 2022-09-22 07:36:37
首先,如果您使用vscode,我建议您使用它来调试:
https://marketplace.visualstudio.com/items?itemName=elazarcoh.simply-view-image-for-python-debugging
从我的脑海中,我会说,你应该把这些值相加在一个轴上(就像我想象的那样,它们是一个热编码的):
np.sum(A,axis = 1)
从这一点上说:
https://stackoverflow.com/questions/73807983
复制相似问题