首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >keras模型与tf.keras模型的兼容性

keras模型与tf.keras模型的兼容性
EN

Stack Overflow用户
提问于 2019-07-22 18:51:30
回答 1查看 911关注 0票数 4

我感兴趣的是在tf.keras中训练一个模型,然后用keras加载它。我知道这不是很明智,但我对使用tf.keras来训练模型很感兴趣,因为

  1. tf.keras更容易构建输入管道。
  2. 我想利用tf.dataset API

我很有兴趣把它装上角因为

  1. 我想使用coreml将模型部署到ios。
  2. 我希望使用coremltools将我的模型转换为ios,而coreml工具只适用于keras,而不是tf.keras。

我遇到了一些路障,因为并不是所有的tf.keras层都可以作为keras层加载。例如,我没有遇到简单的DNN问题,因为在tf.keras和keras之间,所有的密集层参数都是相同的。但是,我在RNN层方面遇到了麻烦,因为tf.keras有一个keras没有的参数time_major。我的RNN层有time_major=False,这是与keras相同的行为,但是keras顺序层没有这个参数。

我现在的解决方案是将tf.keras模型保存在json文件中(用于模型结构),删除keras不支持的层的部分,还保存一个h5文件(用于权重),如下所示:

代码语言:javascript
复制
model = # model trained with tf.keras

# save json
model_json = model.to_json()
with open('path_to_model_json.json', 'w') as json_file:
    json_ = json.loads(model_json)
    layers = json_['config']['layers']
    for layer in layers:
        if layer['class_name'] == 'SimpleRNN':
            del layer['config']['time_major']
    json.dump(json_, json_file)

# save weights
model.save_weights('path_to_my_weights.h5')

然后,我使用coreml转换器工具从keras转换为coreml,如下所示:

代码语言:javascript
复制
with CustomObjectScope({'GlorotUniform': glorot_uniform()}):
    coreml_model = coremltools.converters.keras.convert(
        model=('path_to_model_json','path_to_my_weights.h5'),
        input_names=#inputs, 
        output_names=#outputs,
        class_labels = #labels, 
        custom_conversion_functions = { "GlorotUniform": tf.keras.initializers.glorot_uniform
                                            }
    )
    coreml_model.save('my_core_ml_model.mlmodel')

我的解决方案似乎奏效了,但我想知道是否有更好的办法?或者,这种做法是否存在迫在眉睫的危险?例如,是否有更好的方法将tf.keras模型转换为coreml?或者是否有更好的方法将tf.keras模型转换为keras?还是有更好的方法我还没想过?

如能就此提供任何建议,将不胜感激:)

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-08-13 16:37:12

你的方法对我来说很好!

在过去,当我不得不将tf.keras模型转换为keras模型时,我做了以下工作:

  • tf.keras中的列车模型
  • 只保存权重tf_model.save_weights("tf_model.hdf5")
  • 使Keras模型体系结构使用keras中的所有层(与tf keras 1相同)
  • 在keras中按层名加载权重:keras_model.load_weights(by_name=True)

这似乎对我有用。由于我使用的是开箱即用的体系结构(DenseNet169),所以我不得不将tf.keras网络复制到keras中的工作量非常少。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57152123

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档