首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何从量化的TFLite中获取类索引?

如何从量化的TFLite中获取类索引?
EN

Stack Overflow用户
提问于 2021-01-20 02:25:14
回答 2查看 126关注 0票数 2

我一直在用TensorFlow训练一个量化的Mobilenet V2,但我不知道如何从它获得类索引。

我使用的是Tensorflow 1.12

下面是我的输入和输出的详细信息。

代码语言:javascript
复制
Input details [{'name': 'normalized_input_image_tensor', 'index': 260, 'shape': array([  1, 300, 300,   3], dtype=int32), 'shape_signature': array([  1, 300, 300,   3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.0078125, 128), 'quantization_parameters': {'scales': array([0.0078125], dtype=float32), 'zero_points': array([128], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
Output details [{'name': 'TFLite_Detection_PostProcess', 'index': 252, 'shape': array([ 1, 10,  4], dtype=int32), 'shape_signature': array([ 1, 10,  4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 253, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 254, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 255, 'shape': array([1], dtype=int32), 'shape_signature': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]

我一直在尝试通过执行以下操作来获取类索引:

代码语言:javascript
复制
interpreter = tf.lite.Interpreter(model_path=PATH)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
classes = interpreter.get_tensor(output_details[1]['index'])[0]

然而,类索引是不正确的。打印时,classes看起来像这样:[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]。我的数据集中有1个以上的类,所以这没有任何意义。

获取类索引的正确方法是什么?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-01-23 08:38:56

经过大量的实验,结果证明这不是一个量化问题。我们在创建.tflite时使用了错误的graph_def .pb文件,因此它预测了不存在的类。

票数 1
EN

Stack Overflow用户

发布于 2021-01-20 04:19:11

尝试使用:

代码语言:javascript
复制
classes = interpreter.get_tensor(output_details[0]['index'])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65797318

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档