首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >pytorch急切量化跳过模块

pytorch急切量化跳过模块
EN

Stack Overflow用户
提问于 2022-06-23 13:24:37
回答 1查看 159关注 0票数 0

我正在使用急切的模式量化。然而,我想跳过一些层从量化。我正在学习教程这里

但是,当我现在测试该模型时,我得到了以下错误:

无法使用“aten::_slow_conv2d_forward”后端的参数运行“QuantizedCPU”。

如果我正确理解的话,这是因为qconfig = none的层在接收量化数据的同时,也在等待去量化的数据。有没有一种方法,我可以添加指令,在层之前去量化数据,在层之后,在循环中量化数据?或者我还能做什么其他的解决办法呢?

排除层的代码:

代码语言:javascript
复制
for quantized_layer, _ in fused_model.named_modules():
   if (quantized_layer in sortedSensitivityDict):
      if sortedSensitivityDict[quantized_layer] > 0.94:
        _.qconfig = torch.quantization.get_default_qconfig("qnnpack")
      else:
        _.qconfig = None

量化代码:

代码语言:javascript
复制
import torch.optim as optim
model_fp32_prepared = torch.quantization.prepare(fused_model)

def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

calibrate(model_fp32_prepared, val_loader)
model_fp32_prepared.eval()
model_int8 = torch.quantization.convert(model_fp32_prepared)

主要问题是我使用的是MobileNetV3,其中前向函数如下所示:

代码语言:javascript
复制
def _forward_impl(self, x: Tensor) -> Tensor:
      x = self.features(x)
      x = self.avgpool(x)
      x = torch.flatten(x, 1)
      x = self.classifier(x)

由于这些层都在self.features中,所以我不知道如何使用self.quantself.dequant

EN

回答 1

Stack Overflow用户

发布于 2022-06-23 15:56:59

这里的博客作者-这可能是相当棘手的渴望模式不幸。我们有一个新的API使用FX图形模式,使这些操作更容易。您不需要设置每个模块的qconfig,而是可以传递一个具有要禁用的层名的dict。

类似于:

代码语言:javascript
复制
disable_layers = []
for quantized_layer, _ in fused_model.named_modules():
   if (quantized_layer in sortedSensitivityDict):
      if sortedSensitivityDict[quantized_layer] > 0.94:
          disable_layers.append(quantized_layer)

qconfig_dict = {
    # Global Config
    "": torch.quantization.get_default_qconfig("qnnpack"),

    # Disable by layer-name
    "module_name": [(m, None) for m in disable_layers],

    # Or disable by layer-type
    "object_type": [
        (torch.nn.functional.add, None),  # skips quantization for all functional.add layers
        ...,
        ],
}

model_fp32_prepared = torch.quantization.quantize_fx.prepare_fx(model, qconfig_dict)

# calibrate as usual

model_int8 = torch.quantization.quantize_fx.convert_fx(model_fp32_prepared)

FYR,我有一个笔记本正在浏览这个工作流:Workflow.ipynb

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72730969

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档