首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >全量化除int8数据外,不能将模型输入层更改为int8。

全量化除int8数据外,不能将模型输入层更改为int8。
EN

Stack Overflow用户
提问于 2020-09-10 13:17:55
回答 1查看 585关注 0票数 0

我正在将角星h5模型量化为uint8。为了得到完全的uint8量化,用户dtlam26这个职位中告诉我,有代表性的数据集应该已经在uint8中了,否则输入层仍然在float32中。

问题是,如果我输入uint8数据,则在调用converter.convert()期间会收到以下错误

ValueError:不能设置张量:获取类型为INT8的张量,但输入FLOAT32类型为FLOAT32,名称: input_1

看起来,这个模型仍然期待着float32。所以我检查了基本的keras_vggface预训练模型(从这里开始)

代码语言:javascript
复制
from keras_vggface.vggface import VGGFace
import keras

pretrained_model = VGGFace(model='resnet50', include_top=False, input_shape=(224, 224, 3), pooling='avg')  # pooling: None, avg or max

pretrained_model.save()

得到的h5模型具有float32的输入层。接下来,我使用uint8作为输入dtype更改了模型定义:

代码语言:javascript
复制
def RESNET50(include_top=True, weights='vggface',
             ...)

    if input_tensor is None:
        img_input = Input(shape=input_shape, dtype='uint8')

但对于int,只允许使用int32。但是,使用int32会导致问题,下面的层期望使用float32。

这似乎不是为所有层手动执行此操作的正确方法。

为什么在量化过程中我的模型不包括uint8数据,并自动将输入更改为uint8?

我错过了什么?你知道解决办法吗?非常感谢。

EN

回答 1

Stack Overflow用户

发布于 2020-09-11 13:32:14

来自用户 dtlam26的解决方案

尽管模型仍然没有在google中运行,但是使用TF 1.15.3或TF2.2.0在int8中量化模型和输入输出的解决方案是,多亏了delan:

代码语言:javascript
复制
...
converter = tf.lite.TFLiteConverter.from_keras_model_file(saved_model_dir + modelname) 
        
def representative_dataset_gen():
  for _ in range(10):
    pfad='pathtoimage/000001.jpg'
    img=cv2.imread(pfad)
    img = np.expand_dims(img,0).astype(np.float32) 
    # Get sample input data as a numpy array in a method of your choosing.
    yield [img]
    
converter.representative_dataset = representative_dataset_gen

converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.experimental_new_converter = True

converter.target_spec.supported_types = [tf.int8]
converter.inference_input_type = tf.int8 
converter.inference_output_type = tf.int8 
quantized_tflite_model = converter.convert()
if tf.__version__.startswith('1.'):
    open("test153.tflite", "wb").write(quantized_tflite_model)
if tf.__version__.startswith('2.'):
    with open("test220.tflite", 'wb') as f:
        f.write(quantized_tflite_model)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63830570

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档