首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将图形(pb)转换为SavedModel进行gcloud ml-engine预测

将图形(pb)转换为SavedModel进行gcloud ml-engine预测
EN

Stack Overflow用户
提问于 2017-06-19 20:32:15
回答 1查看 1.4K关注 0票数 5

根据Google的德里克·周最近在谷歌云大数据和机器学习博客上发表文章,我使用云机器学习引擎训练了一个对象检测器,现在我想使用云机器学习引擎进行预测。

这些指令包括将Tensorflow图导出为output_inference_graph.pb的代码,但不包括如何将protobuf格式(pb)转换为gcloud预测所需的SavedModel格式。

我回顾了谷歌@rhaertel80 80的答复关于如何转换“Tensorflow for Poets”图像分类模型和如何转换“Tensorflow For Poets 2”图像分类模型的谷歌公司@MarkMcDonald提供的答案,但是对于博客文章中描述的对象检测器图(pb),两者都不起作用。

如何转换该物体探测器图(pb),以便它可以使用或gcloud引擎预测,请?

EN

回答 1

Stack Overflow用户

发布于 2018-08-31 10:13:10

这个帖子救了我!希望能帮助到这里的人。我使用导出成功的https://stackoverflow.com/a/48102615/6124383的方法

https://github.com/tensorflow/tensorflow/pull/15855/commits/81ec5d20935352d71ff56fac06c36d6ff0a7ae05

代码语言:javascript
复制
def export_model(sess, architecture, saved_model_dir):
  if architecture == 'inception_v3':
    input_tensor = 'DecodeJpeg/contents:0'
  elif architecture.startswith('mobilenet_'):
    input_tensor = 'input:0'
  else:
    raise ValueError('Unknown architecture', architecture)
  in_image = sess.graph.get_tensor_by_name(input_tensor)
  inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
   out_classes = sess.graph.get_tensor_by_name('final_result:0')
  outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)}
   signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs=inputs,
    outputs=outputs,
    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
  )
   legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
   # Save out the SavedModel.
  builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
  builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
    },
    legacy_init_op=legacy_init_op)
  builder.save()

#execute this in the final of def main(_):
export_model(sess, FLAGS.architecture, FLAGS.saved_model_dir)

parser.add_argument(
      '--saved_model_dir',
      type=str,
      default='/tmp/saved_models/1/',
      help='Where to save the exported graph.'
  )
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/44639463

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档