首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么在对一个AssertionError模型进行静态量化时,我得到了'AssertionError:没有找到用于:‘的fuser方法?

为什么在对一个AssertionError模型进行静态量化时,我得到了'AssertionError:没有找到用于:‘的fuser方法?
EN

Stack Overflow用户
提问于 2022-03-02 12:48:56
回答 1查看 679关注 0票数 0

当我试图在模型上应用静态量化时,我得到了下面的错误。错误出现在代码的fuse部分:torch.quantization.fuse_modules(model, modules_to_fuse)

代码语言:javascript
复制
model = torch.quantization.fuse_modules(model, modules_to_fuse)
  File "/Users/celik/PycharmProjects/GFPGAN/colorization/lib/python3.8/site-packages/torch/ao/quantization/fuse_modules.py", line 146, in fuse_modules
    _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
  File "/Users/celik/PycharmProjects/GFPGAN/colorization/lib/python3.8/site-packages/torch/ao/quantization/fuse_modules.py", line 77, in _fuse_modules
    new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
  File "/Users/celik/PycharmProjects/GFPGAN/colorization/lib/python3.8/site-packages/torch/ao/quantization/fuse_modules.py", line 45, in fuse_known_modules
    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
  File "/Users/celik/PycharmProjects/GFPGAN/colorization/lib/python3.8/site-packages/torch/ao/quantization/fuser_method_mappings.py", line 132, in get_fuser_method
    assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
AssertionError: did not find fuser method for: (<class 'torch.nn.modules.conv.Conv2d'>,) 
EN

回答 1

Stack Overflow用户

发布于 2022-03-02 13:28:08

modules_to_fuse列表应遵守以下规则:

代码语言:javascript
复制
Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, relu
    bn, relu
    All other sequences are left unchanged.
    For these sequences, replaces the first item in the list
    with the fused module, replacing the rest of the modules
    with identity.

我不能融合'torch.nn.modules.conv.Conv2d'的模型。它应该与诸如"conv,bn“或"conv,bn,relu”或“conv,relu”之类的其他组合融合在一起。使用上面的列表准备您的融合列表。对我起作用了。

这里还有另一个融合方法列表:

代码语言:javascript
复制
DEFAULT_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
(nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
(nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
(nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
(nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
(nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
(nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
(nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
(nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
(nn.Linear, nn.ReLU): nni.LinearReLU,
(nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
(nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,}
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71322979

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档