我在将Tensorflow模型转换为TensorflowLite时遇到了一个问题。我想用量化转换整个模型,但当我完成这一步并可视化模型的体系结构时,我发现输入和输出仍然是浮点型(float)。你能帮我解决这个问题吗?
版本信息: tensorflow 2.3.1 / python 3.6
用于验证的数据
# Validation data generator built from a directory of class subfolders.
# NOTE(review): validation_datagen, valid_data_dir, img_width, img_height and
# classes are defined outside this snippet — confirm against the full script.
# class_mode='categorical' yields one-hot labels matching the softmax head.
validation_generator = validation_datagen.flow_from_directory(
    valid_data_dir,
    target_size=(img_width, img_height),
    classes=classes,
    batch_size=32,
    class_mode='categorical',
)模型的体系结构
# Model architecture: three Conv2D/AveragePooling2D/Dropout stages followed by
# a dense classifier with a 4-way softmax head. Input is 128x128 RGB images.
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(128, 128, 3)))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))
model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Dropout(0.3))
model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
model.add(AveragePooling2D(pool_size=(2, 2)))
# Consistency fix: use the directly imported Dropout like every other layer
# (the original mixed in `layers.Dropout`, which needs a separate
# `from tensorflow.keras import layers` import not shown in this snippet).
model.add(Dropout(0.4))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(4, activation='softmax'))
model.summary()
# 训练/将模型转换为Tflite后 (after training / converting the model to TFLite)
# Question's original (buggy) conversion code, kept as-is for context.
# Representative dataset for post-training full-integer quantization:
# intended to yield individual validation images shaped [1, 128, 128, 3].
def representative_dataset_gen():
    for i in range(20):
        data_x, data_y = validation_generator.next()
        for data_xx in data_x:
            # BUG: `data` is read before it is ever assigned (NameError) —
            # this should reshape `data_xx` instead.
            data = tf.reshape(data, shape=[-1, 128, 128, 3])
            yield [data]
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops =[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# BUG: the attribute names below are misspelled ("tpye"). Setting a
# nonexistent attribute is silently accepted, so the converter keeps its
# float defaults — which is exactly why the converted model's input and
# output stayed float. Correct names: inference_input_type /
# inference_output_type.
converter.inference_input_tpye = tf.int8
converter.inference_output_tpye = tf.int8
quantiz_model = converter.convert()
open("/content/drive/My Drive/model.tflite", "wb").write(quantiz_model)发布于 2021-05-15 21:33:22
来自评论
看起来你打错了,应该是
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
(改写自daverim)
将模型转换为Tflite的工作代码
# Working conversion code: post-training full-integer (int8) quantization.
def representative_dataset_gen():
    """Yield single validation images as calibration data for quantization.

    Draws 20 batches from `validation_generator` and yields each image
    individually, reshaped to [1, 128, 128, 3] as the converter expects.
    """
    for _ in range(20):
        data_x, data_y = validation_generator.next()
        for data_xx in data_x:
            # Fix: reshape the current image `data_xx`. The original
            # reshaped the not-yet-defined name `data`, which raises
            # NameError the first time the converter pulls a sample.
            data = tf.reshape(data_xx, shape=[-1, 128, 128, 3])
            yield [data]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
# Restrict to int8 builtin ops AND quantize the I/O tensors, so the
# converted model's input/output are int8 rather than float.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
quantiz_model = converter.convert()
# Use a context manager so the file handle is closed deterministically
# (the original left the handle from open() unclosed).
with open("/content/drive/My Drive/model.tflite", "wb") as f:
    f.write(quantiz_model)
# https://stackoverflow.com/questions/66551794
复制相似问题