首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么火炬剪枝不实际去除过滤器或重量?

为什么火炬剪枝不实际去除过滤器或重量?
EN

Stack Overflow用户
提问于 2021-09-24 08:28:06
回答 1查看 959关注 0票数 1

我使用一种架构,并试图通过剪枝来稀疏它。我编写了剪枝函数,以下是其中之一:

代码语言:javascript
复制
def prune_model_l1_unstructured(model, layer_type, proportion):
    for module in model.modules():
        if isinstance(module, layer_type):
            prune.l1_unstructured(module, 'weight', proportion)
            prune.remove(module, 'weight')
    return model

# prune model
prune_model_l1_unstructured(model, nn.Conv2d, 0.5)

它修剪一些权重(将它们更改为零)。但是prune.remove只删除原始权重,而保留零。参数的总数仍然相同(我检查过了)。模型的文件(model.pt)大小也一样。模型的“速度”在它之后仍然保持不变。我还尝试了全局剪枝和结构化L1剪枝,结果是一样的。那么,这如何有助于提高模型的性能时间呢?为什么重量不被移除,以及如何删除被修剪的连接?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-09-24 10:08:19

TLDR;PyTorch prune的功能只是作为一个重量掩码,这就是它所做的一切。没有与使用torch.nn.utils.prune相关的内存节省。

正如torch.nn.utils.prune.remove的文档页面所述:

从模块中移除剪枝重新参数化,从前向挂钩中移除剪枝方法。

实际上,这意味着它将从参数中移除prune.l1_unstructured添加的掩码。作为一个副作用,删除剪枝将意味着对以前隐藏的值有零,但这些值不会停留在0的值上。最终,PyTorch prune只会占用比不使用它更多的内存。所以这并不是你想要的功能。

您可以在这句话上阅读更多内容。

下面是一个示例:

代码语言:javascript
复制
>>> module = nn.Linear(10,3)
>>> prune.l1_unstructured(module, name='weight', amount=0.3)

重量参数被蒙蔽:

代码语言:javascript
复制
>>> module.weight
tensor([[-0.0000,  0.0000, -0.1397, -0.0942,  0.0000,  0.0000,  0.0000, -0.1452,
          0.0401,  0.1098],
        [ 0.2909, -0.0000,  0.2871,  0.1725,  0.0000,  0.0587,  0.0795, -0.1253,
          0.0764, -0.2569],
        [ 0.0000, -0.3054, -0.2722,  0.2414,  0.1737, -0.0000, -0.2825,  0.0685,
          0.1616,  0.1095]], grad_fn=<MulBackward0>)

这是面具:

代码语言:javascript
复制
>>> module.weight_mask
tensor([[0., 0., 1., 1., 0., 0., 0., 1., 1., 1.],
        [1., 0., 1., 1., 0., 1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 0., 1., 1., 1., 1.]])

注意,在应用prune.remove时,删除了剪枝。而且,蒙面值仍为零,但“未冻结”。

代码语言:javascript
复制
>>> prune.remove(module, 'weight')

>>> module.weight
tensor([[-0.0000,  0.0000, -0.1397, -0.0942,  0.0000,  0.0000,  0.0000, -0.1452,
          0.0401,  0.1098],
        [ 0.2909, -0.0000,  0.2871,  0.1725,  0.0000,  0.0587,  0.0795, -0.1253,
          0.0764, -0.2569],
        [ 0.0000, -0.3054, -0.2722,  0.2414,  0.1737, -0.0000, -0.2825,  0.0685,
          0.1616,  0.1095]], grad_fn=<MulBackward0>)

面具也不见了:

代码语言:javascript
复制
>>> hasattr(module, 'weight_mask')
False
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69311857

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档