首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >火炬L1-规范剪枝是如何工作的?

火炬L1-规范剪枝是如何工作的?
EN

Stack Overflow用户
提问于 2021-12-14 09:15:34
回答 1查看 695关注 0票数 0

让我们看看我第一次得到的结果。这是我的模型的一个卷积层,我只显示了11个滤波器的权重(113x3带channel=1)。

左边是原来的重量,右边是修剪的重量。

所以我想知道"TORCH.NN.UTILS.PRUNE.L1_UNSTRUCTURED“是如何工作的,因为火把网站上说修剪了最低的L1-范数单位,但据我所知,L1-范数剪枝是一种过滤剪枝方法,它剪枝整个过滤器,使用这个方程式来细化最低的过滤器值,而不是修剪单个权重。所以我有点好奇这个函数是如何工作的?

以下是我的剪枝代码

代码语言:javascript
复制
parameters_to_prune = (
    (model.input_layer[0], 'weight'),
    (model.hidden_layer1[0], 'weight'),
    (model.hidden_layer2[0], 'weight'),
    (model.output_layer[0], 'weight')
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount = (pruned_percentage/100),
)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-12-14 09:52:11

nn.utils.prune.l1_unstructured实用程序并不会修剪整个过滤器,它会像您在工作表中看到的那样,对单个参数组件进行修剪。这是低范数的组件被蒙住了。

下面是一个很小的例子,在下面的注释中讨论过:

代码语言:javascript
复制
>>> m = nn.Linear(10,1,bias=False)
>>> m.weight = nn.Parameter(torch.arange(10).float())
>>> prune.l1_unstructured(m, 'weight', .3)
>>> m.weight
tensor([0., 0., 0., 3., 4., 5., 6., 7., 8., 9.], grad_fn=<MulBackward0>)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70346398

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档