首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Keras_Segmentation VGG导致AttributeError:'Functional‘object没有属性'output_width’

Keras_Segmentation VGG导致AttributeError:'Functional‘object没有属性'output_width’
EN

Stack Overflow用户
提问于 2022-04-13 05:17:03
回答 1查看 174关注 0票数 0

我用图像分割-角化库创建了一个模型,方法是按如下方式初始化它:

代码语言:javascript
复制
import keras_segmentation
from keras_segmentation.models.unet import vgg_unet
from tensorflow.keras.layers import Input

model = vgg_unet(n_classes=21 ,  input_height=256, input_width=448)

然后我把它训练成这样:

代码语言:javascript
复制
model.train(
    train_images =  "/content/drive/MyDrive/imgs_train/",
    train_annotations = "/content/drive/MyDrive/masks_train/",
    val_images =  "/content/drive/MyDrive/mgs_validation/",
    val_annotations = "/content/drive/MyDrive/masks_validation/",
    checkpoints_path = "/content/drive/MyDrive/tmp/vgg_unet_1" , 
    epochs=28,validate=True,callbacks = [myCallback])

model.load_weights('checkpoint_filepath')

然后像这样保存它:

代码语言:javascript
复制
model.save('/content/drive/MyDrive/vgg_unet_segmentation.h5')

然后按如下方式加载:

代码语言:javascript
复制
model = load_model('/content/drive/MyDrive/vgg_unet_segmentation.h5'')

但是,当我试图通过执行out = model.predict_segmentation(inp=inp, out_fname="/tmp/out.png")进行预测时,我会得到以下错误:

代码语言:javascript
复制
AttributeError: 'Functional' object has no attribute 'predict_segmentation'

因此,为了解决这个问题,我做了以下工作:

代码语言:javascript
复制
from types import MethodType
model = load_model('/content/drive/MyDrive/vgg_unet_segmentation.h5'')
model.predict_segmentation = MethodType(keras_segmentation.predict.predict, model)

然而,这导致了另一个我无法解决的问题:

代码语言:javascript
复制
[<ipython-input-7-a4b7d02cd9a2>](https://localhost:8080/#) in <module>()
      4 out = model.predict_segmentation(
      5     inp=inp,
----> 6     out_fname="/tmp/out.png")

[/content/image-segmentation-keras/keras_segmentation/predict.py](https://localhost:8080/#) in predict(model, inp, out_fname, checkpoints_path, overlay_img, class_names, show_legends, colors, prediction_width, prediction_height, read_image_type)
    148     assert (len(inp.shape) == 3 or len(inp.shape) == 1 or len(inp.shape) == 4), "Image should be h,w,3 "
    149 
--> 150     output_width = model.output_width
    151     output_height = model.output_height
    152     input_width = model.input_width

AttributeError: 'Functional' object has no attribute 'output_width'

知道为什么会发生这种情况吗?如果是的话,该如何解决呢?

任何帮助都是非常感谢的!

谢谢!

EN

回答 1

Stack Overflow用户

发布于 2022-04-14 11:18:20

为输出尝试model.predict(),如下代码所示:

代码语言:javascript
复制
prediction = (model.predict(test_img_input))
predicted_img=np.argmax(prediction, axis=3)[0,:,:]
plt.imshow(predicted_img)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71852237

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档