我使用一个定制的图像集来训练一个使用Tensorflow API的神经网络。在成功的培训过程中,我得到了这些检查点文件,其中包含了不同训练变量的值。现在我想从这些检查点文件中获得一个推理模型,我找到了这个脚本,我可以使用它生成深度梦图像,如本教程所解释的那样。问题是当我使用以下方法加载模型时:
import tensorflow as tf
model_fn = 'export'
graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)
with tf.gfile.FastGFile(model_fn, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
t_input = tf.placeholder(np.float32, name='input')
imagenet_mean = 117.0
t_preprocessed = tf.expand_dims(t_input-imagenet_mean, 0)
tf.import_graph_def(graph_def, {'input':t_preprocessed})我知道这个错误:
graph_def.ParseFromString(f.read()) Self.MergeFromString(序列化) 引发message_mod.DecodeError(“意外结束-组标签.”)google.protobuf.message.DecodeError:意外的结束-组标记。
脚本需要一个协议缓冲区文件,我不确定我用来生成推理模型的脚本是否给了我proto缓冲区文件。
请有人建议我做错了什么,或者有更好的方法来实现这一点。我只想将张量生成的检查点文件转换为proto缓冲区。
谢谢
发布于 2018-02-05 19:08:51
指向您运行的脚本的链接已经中断,但无论如何,推荐的做法不是尝试从检查点生成推理模型,而是在培训程序的末尾嵌入代码,该代码将发出"SavedModel“导出(这与检查点不一样)。
请参阅1,特别是标题“构建一个保存的模型”。请注意,保存的模型包含多个文件,其中一个确实是协议缓冲区(我希望直接回答您的问题);其他文件是可变文件和(可选)资产文件。
https://stackoverflow.com/questions/38698499
复制相似问题