我正在尝试在android中部署这个mask-rcnn模型。我能够加载keras权重,冻结模型,并使用this script使用tflite 1.13toco将其转换为.tflite模型。
似乎这个模型使用了一些在tflite中不支持的tf_ops。因此我不得不使用
converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,tf.lite.OpsSet.SELECT_TF_OPS]转换模型的步骤。现在,当我尝试使用python解释器来推断这个模型时,我在interpreter.invoke()中得到分割错误,并且python脚本崩溃。
def run_tf_model(model_path="mask_rcnn_coco.tflite"):
interpreter = tf.lite.Interpreter(model_path)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]
print(" input_details", input_details)
print("output_details",output_details)
# Test model on random input data.
input_shape = input_details['shape']
print("input_shape tflite",input_shape)
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details['index'], input_data)
interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details['index'])
print(output_data)因此,我无法确定转换后的模型是否已正确转换。
附注:我计划在android中使用此模型,但我对android(java或kotlin) tflite api几乎没有经验。如果任何人能指出任何学习资源,也将是有帮助的。
编辑:我也尝试了用java api在android上运行推理。但是会得到以下错误tensorflow/lite/kernels/gather.cc:80 0 <= axis && axis < NumDimensions(input). Detailed in this tensorflow issue
发布于 2020-09-06 00:05:50
您可以使用tflite python解释器验证您的自定义训练的TFLite模型。Reference
import numpy as np
import tensorflow as tf
# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Test the model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)https://stackoverflow.com/questions/61571375
复制相似问题