首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用tflite为android转换Mask-RCNN模型

使用tflite为android转换Mask-RCNN模型
EN

Stack Overflow用户
提问于 2020-05-03 16:05:39
回答 1查看 1.1K关注 0票数 1

我正在尝试在android中部署这个mask-rcnn模型。我能够加载keras权重,冻结模型,并使用this script使用tflite 1.13toco将其转换为.tflite模型。

似乎这个模型使用了一些在tflite中不支持的tf_ops。因此我不得不使用

代码语言:javascript
复制
converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,tf.lite.OpsSet.SELECT_TF_OPS]

转换模型的步骤。现在,当我尝试使用python解释器来推断这个模型时,我在interpreter.invoke()中得到分割错误,并且python脚本崩溃。

代码语言:javascript
复制
def run_tf_model(model_path="mask_rcnn_coco.tflite"):
    interpreter = tf.lite.Interpreter(model_path)
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    print(" input_details", input_details)
    print("output_details",output_details)

    # Test model on random input data.
    input_shape = input_details['shape']
    print("input_shape tflite",input_shape)
    input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details['index'], input_data)
    interpreter.invoke()

    # The function `get_tensor()` returns a copy of the tensor data.
    # Use `tensor()` in order to get a pointer to the tensor.
    output_data = interpreter.get_tensor(output_details['index'])
    print(output_data)

因此,我无法确定转换后的模型是否已正确转换。

附注:我计划在android中使用此模型,但我对android(java或kotlin) tflite api几乎没有经验。如果任何人能指出任何学习资源,也将是有帮助的。

编辑:我也尝试了用java api在android上运行推理。但是会得到以下错误tensorflow/lite/kernels/gather.cc:80 0 <= axis && axis < NumDimensions(input). Detailed in this tensorflow issue

EN

回答 1

Stack Overflow用户

发布于 2020-09-06 00:05:50

您可以使用tflite python解释器验证您的自定义训练的TFLite模型。Reference

代码语言:javascript
复制
import numpy as np
import tensorflow as tf

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61571375

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档