首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用INT8量化tflite模型运行推理

用INT8量化tflite模型运行推理
EN

Stack Overflow用户
提问于 2021-03-23 16:23:44
回答 1查看 1.3K关注 0票数 0

**大家好,我最近把tensorflow浮子模型转换成了tflite量化的INT8模型,最后我得到了模型,没有错误。我想在python中使用这个模型进行推论,但是我不能得到好的结果。代码如下:**

变换TF模型

代码语言:javascript
复制
 def representative_dataset_gen():
    for i in range(20):
        data_x, data_y = validation_generator.next()
        for data_xx in data_x:
            data = tf.reshape(data, shape=[-1, 128, 128, 3])
            yield [data]

converter = tf.lite.TFLiteConverter.from_keras_model(model)

converter.optimizations = [tf.lite.Optimize.DEFAULT]

converter.representative_dataset = representative_dataset_gen

converter.target_spec.supported_ops =[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

converter.inference_input_type  = tf.int8

converter.inference_output_type = tf.int8

quantized_model = converter.convert()

open("/content/drive/My Drive/model.tflite", "wb").write(quantized_model)

运行推理

代码语言:javascript
复制
tflite_file='./model_google.tflite'
img_name='./img_test/1_2.jpg'

test_image = load_img(img_name, target_size=(128, 128))
test_image = img_to_array(test_image)

test_image = test_image.reshape(1, 128, 128,3)
#test_image = test_image.astype('float32')


interpreter = tf.lite.Interpreter(model_path=(tflite_file))
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()[0]


input_scale, input_zero_point = input_details['quantization']


test_image_int = test_image / input_scale + input_zero_point
test_image_int=test_image_int.astype(input_details['dtype'])




interpreter.set_tensor(input_details['index'], test_image_int)
interpreter.invoke()

output_details = interpreter.get_output_details()[0]

output = interpreter.get_tensor(output_details['index'])

scale, zero_point= output_details['quantization']

tflite_output=output.astype(np.float32)
tflite_output= (tflite_output- zero_point)* scale

print(input_scale)
print(tflite_output)
print(input_details["quantization"])

,你能告诉我如何用这个量化的模型(输入和输出转换成INT8)来预测一个类,并且有正确的概率值吗?

EN

回答 1

Stack Overflow用户

发布于 2021-03-24 08:40:04

Hi Jae,谢谢您的回答,附件是具有代表性的数据集代码:

代码语言:javascript
复制
train_datagen =  ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=30,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=[0.6, 1.1],
    horizontal_flip=True,
    brightness_range=[0.8, 1.3],
    channel_shift_range=2.0,
    fill_mode='nearest')

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    classes=classes,
    class_mode='categorical',
    )
def representative_dataset_gen():
    for i in range(10):
        data_x, data_y = train_generator.next()
        for data in data_x:
            data = tf.reshape(data, shape=[-1, 128, 128, 3])
            yield [data]

我使用训练数据集中的数据进行量化,您能告诉我如何在发送到输入之前进行图像处理,以及如何在输出端读取推论谢谢

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66767195

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档