首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >AttributeError:在导入TensorFlow模型Keras时,'Functional‘对象没有属性'predict_segmentation’

AttributeError:在导入TensorFlow模型Keras时,'Functional‘对象没有属性'predict_segmentation’
EN

Stack Overflow用户
提问于 2021-12-06 21:25:38
回答 1查看 602关注 0票数 1

我成功地训练了一个Keras模型,比如:

代码语言:javascript
复制
import tensorflow as tf
from keras_segmentation.models.unet import vgg_unet

# initaite the model
model = vgg_unet(n_classes=50, input_height=512, input_width=608)

# Train
model.train(
    train_images=train_images,
    train_annotations=train_annotations,
    checkpoints_path="/tmp/vgg_unet_1", epochs=5
)

并以hdf5格式保存:

代码语言:javascript
复制
tf.keras.models.save_model(model,'my_model.hdf5')

然后我装载我的模型

代码语言:javascript
复制
model=tf.keras.models.load_model('my_model.hdf5')

最后,我想对一幅新的图像进行分割预测。

代码语言:javascript
复制
out = model.predict_segmentation(
    inp=image_to_test,
    out_fname="/tmp/out.png"
)

我得到了以下错误:

代码语言:javascript
复制
AttributeError: 'Functional' object has no attribute 'predict_segmentation'

我做错什么了?是当我保存我的模型还是当我装载它的时候?

谢谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-12-06 23:09:17

predict_segmentation不是普通Keras模型中可用的函数。看起来它是在keras_segmentation库中创建模型之后添加的,这可能就是Keras不能再次加载它的原因。

我想你有两个选择。

  1. 您可以使用我链接的代码中的行手动将函数添加回模型。
代码语言:javascript
复制
model.predict_segmentation = MethodType(keras_segmentation.predict.predict, model)
  1. 在重新加载模型时,您可以使用相同的参数创建一个新的vgg_unet,并按照Keras文档中的建议,将权重从您的hdf5文件传递到该模型。
代码语言:javascript
复制
model = vgg_unet(n_classes=50, input_height=512, input_width=608)
model.load_weights('my_model.hdf5')
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70252159

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档