我正在构建一个应该对花进行分类的模型。所以我用Tensorflow创建了一个模型:
keras.layers.Conv2D(128, (3,3), activation='relu', input_shape=(imageShape[0], imageShape[1],3)),
keras.layers.MaxPooling2D(2,2),
keras.layers.Dropout(0.5),
keras.layers.Conv2D(256, (3,3), activation='relu'),
keras.layers.MaxPooling2D(2,2),
keras.layers.Conv2D(512, (3,3), activation='relu'),
keras.layers.MaxPooling2D(2,2),
keras.layers.Flatten(),
keras.layers.Dropout(0.3),
keras.layers.Dense(280, activation='relu'),
keras.layers.Dense(4, activation='softmax')
opt = tf.keras.optimizers.RMSprop()
model.compile(loss='categorical_crossentropy',
optimizer= opt,
metrics=['accuracy'])在训练期间,我将检查点保存为.h5
checkpoint = ModelCheckpoint("preSaved"+str(time.time())+".h5", monitor='val_loss', verbose=1,
save_best_only=True, save_weights_only=False, mode='auto', period=1)现在,我得到了一个损失相当低的时期,并希望将其转换为.tflite上传到Firebase (在Android Studio App中使用它)。
import tensorflow as tf
new_model= tf.keras.models.load_model(filepath="model.h5")
tflite_converter = tf.lite.TFLiteConverter.from_keras_model(new_model)
tflite_converter.inference_type=tf.uint8
tflite_converter.default_ranges_stats=[min_value,max_value]
tflite_converter.quantized_input_stats={"conv2d_6_input_6:0"[mean,std]}
tflite_converter.post_training_quantize=True
tflite_model = tflite_converter.convert()
open("tf_lite_model.tflite", "wb").write(tflite_model).h5大约有335mb,最终的.tflite got 160mb.But Firebase只允许.tflite到60MB,如果我使用本地型号,它需要几分钟才能加载。我读到.tflite通常更小。在我的模型中或者当我将它转换为.tflite时有问题吗?
https://stackoverflow.com/questions/63435756
复制相似问题