首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在tensorflow2.x中加载旧的keras模型

在tensorflow2.x中加载旧的keras模型
EN

Stack Overflow用户
提问于 2020-04-20 22:39:11
回答 1查看 161关注 0票数 0

我正在尝试加载一个使用tensofrow2.x中的TensorFlow 1.x保存的模型。

使用tensorflow.keras.models.load_model加载旧模型时。我得到一个错误:

代码语言:javascript
复制
AttributeError: module 'tensorflow' has no attribute 'to_float'

任何人对如何解决有建议:)。

EN

回答 1

Stack Overflow用户

发布于 2020-04-21 03:41:43

虽然这个错误的原因对我来说还不清楚,但我想建议一个通用的草图来处理奇怪的检查点情况。它应该可以在TensorFlow 2.1中工作。

代码语言:javascript
复制
checkpoint_filename = '/path/to/our/weird/checkpoint.ckpt'
model = tf.keras.Model( ... ) # TF2.0 Model to initialize with the above checkpoint

from tensorflow.python.training.checkpoint_utils import load_checkpoint, list_variables

reader = load_checkpoint(checkpoint_filename)
for w in model.weights:
    name=w.name.split(':')[0] # See (b/29227106)
    print(f"Loading {name}")
    w.assign(reader.get_tensor(
        # Variable renaming 
        {'/var_name1/in/model':'/var_name1/in/checkpoint',
         '/var_name2/in/model':'/var_name2/in/checkpoint',
         # ...  and so on
         }.get(name,name)))

一般来说,模型变量应该有匹配的名称和形状。在名称不匹配的情况下,通过比较model.weightslist_variables的输出来检查差异,然后更新代码片段的重命名字典。请注意,此方法不会恢复模型的优化器状态。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61325148

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档