我有一些python代码,可以使用Tensorflow的TFRecords和Dataset API来训练网络。我已经使用tf.Keras.layers构建了网络,这无疑是最简单、最快的方法。方便的函数model_to_estimator()
modelTF = tf.keras.estimator.model_to_estimator(
keras_model=model,
custom_objects=None,
config=run_config,
model_dir=checkPointDirectory
)将Keras模型转换为估计器,这使我们能够很好地利用Dataset API,并在训练期间和训练完成时自动将检查点保存到checkPointDirectory。估计器API提供了一些宝贵的功能,例如自动在多个GPU上分配工作负载,例如
distribution = tf.contrib.distribute.MirroredStrategy()
run_config = tf.estimator.RunConfig(train_distribute=distribution)现在,对于大型模型和大量数据,使用某种形式的保存模型在训练后执行预测通常是有用的。从Tensorflow 1.10 (参见https://github.com/tensorflow/tensorflow/issues/19295)开始,tf.keras.model对象支持来自Tensorflow检查点的load_weights()。在Tensorflow文档中简要提到了这一点,但在Keras文档中没有,而且我找不到任何人来展示这方面的示例。在一些新的.py中再次定义模型层之后,我尝试了
checkPointPath = os.path.join('.', 'tfCheckPoints', 'keras_model.ckpt.index')
model.load_weights(filepath=checkPointPath, by_name=False)但这给出了一个NotImplementedError:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
2018-10-01 14:24:49.912087:
Traceback (most recent call last):
File "C:/Users/User/PycharmProjects/python/mercury.classifier reductions/V3.2/wikiTestv3.2/modelEvaluation3.2.py", line 141, in <module>
model.load_weights(filepath=checkPointPath, by_name=False)
File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1526, in load_weights
checkpointable_utils.streaming_restore(status=status, session=session)
File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\training\checkpointable\util.py", line 880, in streaming_restore
"Streaming restore not supported from name-based checkpoints. File a "
NotImplementedError: Streaming restore not supported from name-based checkpoints. File a feature request if this limitation bothers you.我想按照警告的建议使用“基于对象的保护程序”,但我还没有找到通过传递给estimator.train()的RunConfig来做到这一点的方法。
那么,有没有更好的方法将保存的权重重新放回估计器中用于预测?github线程似乎表明这已经实现了(尽管基于错误,可能与我上面尝试的方式不同)。有没有人在TF检查点上成功使用过load_weights()?我还没有找到任何关于如何做到这一点的教程/示例,所以任何帮助都是非常感谢的。
发布于 2018-10-22 02:17:59
我不确定,但也许您可以将keras_model.ckpt.index更改为keras_model.ckpt进行测试。
发布于 2021-02-06 00:39:16
您可以创建单独的图形,正常加载检查点,然后将权重传输到Keras模型:
_graph = tf.Graph()
_sess = tf.Session(graph=_graph)
tf.saved_model.load(_sess, ['serve'], '../tf1_save/')
_weights_all, _bias_all = [], []
with _graph.as_default():
for idx, t_var in enumerate(tf.trainable_variables()):
# substitue variable_scope with your scope
if 'variable_scope/' not in t_var.name: break
print(t_var.name)
val = _sess.run(t_var)
_weights_all.append(val) if idx % 2 == 0 else _bias_all.append(val)
for layer, (weight, bias) in enumerate(zip(_weights_all, _bias_all)):
self.model.layers[layer].set_weights([np.array(weight), np.array(bias)])https://stackoverflow.com/questions/52597523
复制相似问题