我正在尝试加载一个使用tensofrow2.x中的TensorFlow 1.x保存的模型。
使用tensorflow.keras.models.load_model加载旧模型时。我得到一个错误:
AttributeError: module 'tensorflow' has no attribute 'to_float'任何人对如何解决有建议:)。
发布于 2020-04-21 03:41:43
虽然这个错误的原因对我来说还不清楚,但我想建议一个通用的草图来处理奇怪的检查点情况。它应该可以在TensorFlow 2.1中工作。
checkpoint_filename = '/path/to/our/weird/checkpoint.ckpt'
model = tf.keras.Model( ... ) # TF2.0 Model to initialize with the above checkpoint
from tensorflow.python.training.checkpoint_utils import load_checkpoint, list_variables
reader = load_checkpoint(checkpoint_filename)
for w in model.weights:
name=w.name.split(':')[0] # See (b/29227106)
print(f"Loading {name}")
w.assign(reader.get_tensor(
# Variable renaming
{'/var_name1/in/model':'/var_name1/in/checkpoint',
'/var_name2/in/model':'/var_name2/in/checkpoint',
# ... and so on
}.get(name,name)))一般来说,模型变量应该有匹配的名称和形状。在名称不匹配的情况下,通过比较model.weights和list_variables的输出来检查差异,然后更新代码片段的重命名字典。请注意,此方法不会恢复模型的优化器状态。
https://stackoverflow.com/questions/61325148
复制相似问题