我用图像分割-角化库创建了一个模型,方法是按如下方式初始化它:
import keras_segmentation
from keras_segmentation.models.unet import vgg_unet
from tensorflow.keras.layers import Input
model = vgg_unet(n_classes=21 , input_height=256, input_width=448)然后我把它训练成这样:
model.train(
train_images = "/content/drive/MyDrive/imgs_train/",
train_annotations = "/content/drive/MyDrive/masks_train/",
val_images = "/content/drive/MyDrive/mgs_validation/",
val_annotations = "/content/drive/MyDrive/masks_validation/",
checkpoints_path = "/content/drive/MyDrive/tmp/vgg_unet_1" ,
epochs=28,validate=True,callbacks = [myCallback])
model.load_weights('checkpoint_filepath')然后像这样保存它:
model.save('/content/drive/MyDrive/vgg_unet_segmentation.h5')然后按如下方式加载:
model = load_model('/content/drive/MyDrive/vgg_unet_segmentation.h5'')但是,当我试图通过执行out = model.predict_segmentation(inp=inp, out_fname="/tmp/out.png")进行预测时,我会得到以下错误:
AttributeError: 'Functional' object has no attribute 'predict_segmentation'因此,为了解决这个问题,我做了以下工作:
from types import MethodType
model = load_model('/content/drive/MyDrive/vgg_unet_segmentation.h5'')
model.predict_segmentation = MethodType(keras_segmentation.predict.predict, model)然而,这导致了另一个我无法解决的问题:
[<ipython-input-7-a4b7d02cd9a2>](https://localhost:8080/#) in <module>()
4 out = model.predict_segmentation(
5 inp=inp,
----> 6 out_fname="/tmp/out.png")
[/content/image-segmentation-keras/keras_segmentation/predict.py](https://localhost:8080/#) in predict(model, inp, out_fname, checkpoints_path, overlay_img, class_names, show_legends, colors, prediction_width, prediction_height, read_image_type)
148 assert (len(inp.shape) == 3 or len(inp.shape) == 1 or len(inp.shape) == 4), "Image should be h,w,3 "
149
--> 150 output_width = model.output_width
151 output_height = model.output_height
152 input_width = model.input_width
AttributeError: 'Functional' object has no attribute 'output_width'知道为什么会发生这种情况吗?如果是的话,该如何解决呢?
任何帮助都是非常感谢的!
谢谢!
发布于 2022-04-14 11:18:20
为输出尝试model.predict(),如下代码所示:
prediction = (model.predict(test_img_input))
predicted_img=np.argmax(prediction, axis=3)[0,:,:]
plt.imshow(predicted_img)https://stackoverflow.com/questions/71852237
复制相似问题