首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >我们能把线性和BatchNorm参数融合吗?

我们能把线性和BatchNorm参数融合吗?
EN

Stack Overflow用户
提问于 2020-09-02 13:19:24
回答 1查看 415关注 0票数 0

我想从我的模型中删除BatchNorm。所以,我想把它和线性融合。我的模型结构是:

  • Linear -> ReLU -> BatchNorm -> Dropout -> Linear

我尝试过融合BatchNorm -> Linear,但是我无法与可用的代码融合。是否有任何方法将BatchNorm与上述任何一层融合。

EN

回答 1

Stack Overflow用户

发布于 2020-09-03 06:56:57

代码语言:javascript
复制
class DummyModule_1(nn.Module):
    def __init__(self):
        super(DummyModule_1, self).__init__()

    def forward(self, x):
        # print("Dummy, Dummy.")
        return x


def fuse_1(linear, bn):
    w = linear.weight
    print(w.size())
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)

    beta = bn.weight
    gamma = bn.bias

    if linear.bias is not None:
        b = linear.bias
    else:
        b = mean.new_zeros(mean.shape)

    w = w.cuda()
    b = b.cuda()
    w = w * (beta / var_sqrt).reshape([4096, 1])
    b = (b - mean)/var_sqrt * beta + gamma
    fused_linear = nn.Linear(linear.in_features,
                         linear.out_features)
                                             
    fused_linear.weight = nn.Parameter(w)
    fused_linear.bias = nn.Parameter(b)
    return fused_linear


def fuse_module_1(m):
    children = list(m.named_children())
    c = None
    cn = None
    global c1
    global count
    global c18

    for name, child in children:
        print("name is",name,"child is",child)
       
         

        if name == 'linear':
          count = count+1 
          
          if count == 2:
            c18 = child
            print("c18 is",c18)

          else:
            fuse_module_1(child)

        if name =='2' and isinstance(child,nn.BatchNorm1d):
          print("child is",child)
          bc = fuse_1(c18,child)
          m.classifier[1].linear = bc
          m.classifier[2] = DummyModule_1(


        else:
            #fuse_module_1(child)
          fuse_module_1(child)```
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63706428

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档