首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在Tensorflow-Serving中,有可能只获得top-k预测结果吗?

在Tensorflow-Serving中,有可能只获得top-k预测结果吗?
EN

Stack Overflow用户
提问于 2018-12-08 05:55:07
回答 2查看 414关注 0票数 2

当使用https://www.tensorflow.org/serving中的代码,但是使用DNNClassifier估计器模型时,curl/query请求返回所有可能的标签类别及其关联的分数。

使用具有100,000+可能的输出/标签类别的模型,响应变得太大。有没有办法将输出数量限制在前k个结果中?(类似于keras中的实现方式)。

我能想到的唯一可能是通过签名将一些参数提供给predict API,但我还没有找到任何可以提供此功能的参数。我已经阅读了一大堆文档和代码,谷歌了一大堆,但都没有用。

任何帮助都将不胜感激。提前感谢您的回复。<3

EN

回答 2

Stack Overflow用户

发布于 2018-12-17 17:07:09

AFAIC,有两种方法可以满足您的需求。

  1. 您可以在tensorflow中添加一些行,引用tensorflow的源代码可以在训练/重新训练模型时执行类似this的操作。

希望这能有所帮助。

票数 2
EN

Stack Overflow用户

发布于 2019-03-18 23:03:02

把这个放在这里以防对任何人有帮助。可以覆盖head.py (dnn.py使用的)中的classification_output()函数,以便过滤top-k结果。您可以将此代码片段插入到main.py / train.py文件中,当您保存DNNClassifier模型时,该模型在进行推理/服务时将始终输出至多num_top_k_results。该方法的绝大多数内容都是从原始的classification_output()函数复制而来的。(请注意,这可能会与1.13 / 2.0一起工作,也可能不会,因为它还没有在这些平台上测试过。)

代码语言:javascript
复制
from tensorflow.python.estimator.canned import head as head_lib

num_top_k_results = 5

def override_classification_output(scores, n_classes, label_vocabulary=None):
  batch_size = array_ops.shape(scores)[0]
  if label_vocabulary:
    export_class_list = label_vocabulary
  else:
    export_class_list = string_ops.as_string(math_ops.range(n_classes))
  # Get the top_k results
  top_k_scores, top_k_indices = tf.nn.top_k(scores, num_top_k_results)
  # Using the top_k_indices, get the associated class names (from the vocabulary)
  top_k_classes = tf.gather(tf.convert_to_tensor(value=export_class_list), tf.squeeze(top_k_indices))
  export_output_classes = array_ops.tile(
      input=array_ops.expand_dims(input=top_k_classes, axis=0),
      multiples=[batch_size, 1])
  return export_output.ClassificationOutput(
      scores=top_k_scores,
      # `ClassificationOutput` requires string classes.
      classes=export_output_classes)

# Override the original method with our custom one.
head_lib._classification_output = override_classification_output
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53677274

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档