首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >导出器classification_signature

导出器classification_signature
EN

Stack Overflow用户
提问于 2016-04-01 00:37:08
回答 1查看 798关注 0票数 1

我正在尝试修改serving tutorial以使用我的模型,该模型基本上是修改为使用CSV文件和JPEG的CIFAR示例。我似乎找不到Exporter类的文档,但这是我到目前为止所拥有的文档。它在cifar10_train.py文件的train()函数中:

代码语言:javascript
复制
  # Save the model checkpoint periodically.
  if step % 10 == 0 or (step + 1) == FLAGS.max_steps:
    checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=step)

    export_dir = FLAGS.export_dir
    print 'Exporting trained model to ' + FLAGS.export_dir
    export_saver = tf.train.Saver(sharded=True)
    model_exporter = exporter.Exporter(export_saver)
    #
    # TODO: where to find x and y?
    #
    signature = exporter.classification_signature(input_tensor=x, scores_tensor=y)
    model_exporter.init(sess.graph.as_graph_def(),
                        default_graph_signature=signature)
    model_exporter.export(export_dir, tf.constant(FLAGS.export_version), sess)

下面是我用来训练模型的代码:

代码语言:javascript
复制
  labels = numpy.fromfile(os.path.join(data_dir, 'labels.txt'), dtype=numpy.int32, count=-1, sep='\n')

  filenames_and_labels = []

  start_image_number = 1
  end_image_number = 8200

  for i in xrange(start_image_number, end_image_number):
    file_name = os.path.join(data_dir, 'image%d.jpg' % i)
    label = labels[i - 1]
    filenames_and_labels.append(file_name + "," + str(label))


  print('Reading filenames for ' + str(len(filenames_and_labels)) + ' files (from ' + str(start_image_number) + ' to ' + str(end_image_number) + ')')

  for filename_and_label in filenames_and_labels:
    array = filename_and_label.split(",")
    f = array[0]
    # print(array)
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  filename_and_label_queue = tf.train.string_input_producer(filenames_and_labels)

  filename_and_label_tensor = filename_and_label_queue.dequeue()
  filename, label = tf.decode_csv(filename_and_label_tensor, [[""], [""]], ",")
  file_contents = tf.read_file(filename)
  image = tf.image.decode_jpeg(file_contents)

你知道如何正确设置导出器吗?

EN

回答 1

Stack Overflow用户

发布于 2016-05-05 03:12:22

请看一下MNIST export example

这显示了x和y是如何生成的,然后放在签名中。

此外,Inception example还展示了如何扩展现有模型以创建导出和服务。特别是,cifar10.inference调用看起来类似于inception_model.inference

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/36339059

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档