首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TFLITE无法将tensorflow模型的输入和输出量化为INT8

TFLITE无法将tensorflow模型的输入和输出量化为INT8
EN

Stack Overflow用户
提问于 2021-03-10 01:22:08
回答 1查看 139关注 0票数 0

我在将Tensorflow模型转换为TensorflowLite时遇到了一个问题。我想用量化转换整个模型,但当我完成这一步并可视化模型的体系结构时,我发现输入和输出仍然是浮动的。你能帮我解决这个问题吗?

版本信息: tensorflow 2.3.1 / python 3.6

用于验证的数据

代码语言:javascript
复制
validation_generator = validation_datagen.flow_from_directory(
    valid_data_dir,
    target_size=(img_width, img_height),
    classes=classes,
    batch_size=32,
    class_mode='categorical',
    )

模型的体系结构

代码语言:javascript
复制
model = Sequential()

model.add(Conv2D(32, (3, 3), padding='same', activation='relu', input_shape= (128,128,3)))

model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))


model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Dropout(0.3))

model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))

model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(layers.Dropout(0.4))

model.add(Flatten())  
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(4, activation='softmax'))

model.summary()

训练/将模型转换为Tflite后

代码语言:javascript
复制
def representative_dataset_gen():
    for i in range(20):
        data_x, data_y = validation_generator.next()
        for data_xx in data_x:
            data = tf.reshape(data, shape=[-1, 128, 128, 3])
            yield [data]

converter = tf.lite.TFLiteConverter.from_keras_model(model)

converter.optimizations = [tf.lite.Optimize.DEFAULT]

converter.representative_dataset = representative_dataset_gen

converter.target_spec.supported_ops =[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

converter.inference_input_tpye  = tf.int8

converter.inference_output_tpye = tf.int8

quantiz_model = converter.convert()

open("/content/drive/My Drive/model.tflite", "wb").write(quantiz_model)

model properties

EN

回答 1

Stack Overflow用户

发布于 2021-05-15 21:33:22

来自评论

看起来你打错了,应该是converter.inference_input_type = tf.int8 converter.inference_output_type = tf.int8 (改写自daverim)

将模型转换为Tflite的工作代码

代码语言:javascript
复制
def representative_dataset_gen():
    for i in range(20):
        data_x, data_y = validation_generator.next()
        for data_xx in data_x:
            data = tf.reshape(data, shape=[-1, 128, 128, 3])
            yield [data]

converter = tf.lite.TFLiteConverter.from_keras_model(model)

converter.optimizations = [tf.lite.Optimize.DEFAULT]

converter.representative_dataset = representative_dataset_gen

converter.target_spec.supported_ops =[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

converter.inference_input_type  = tf.int8

converter.inference_output_type = tf.int8

quantiz_model = converter.convert()

open("/content/drive/My Drive/model.tflite", "wb").write(quantiz_model)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66551794

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档