首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在LSTM上剪枝增加型号尺寸?

在LSTM上剪枝增加型号尺寸?
EN

Stack Overflow用户
提问于 2021-04-07 13:31:22
回答 1查看 342关注 0票数 1

我是应用剪枝使用的torch.nn.utils.prune在一个模型与低LSTM层。但是,当我保存state_dict的内容时,模型要比剪枝前大得多。我不知道为什么,就像我在剪枝前后打印出state_dict元素的大小一样,所有的内容都是相同的维度,并且在state_dict中没有其他元素。

我的剪枝代码是非常标准的,我一定要调用prune.remove()

代码语言:javascript
复制
        model_state = model.state_dict()
        torch.save(model.state_dict(), 'pre_pruning.pth')
        for param_tensor in model_state:
            print(param_tensor, "\t", model_state[param_tensor].size())

        parameters_to_prune = []
        for param, _ in model.rnn.named_parameters():
            if "weight" in param:
                parameters_to_prune.append((model.rnn, param))
        prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.6)
        for module, param in parameters_to_prune:
            prune.remove(module, param)

        model_state = model.state_dict()
        torch.save(model_state, 'pruned.pth') # This file is much larger than the original
        for param_tensor in model_state:
            print(param_tensor, "\t", model_state[param_tensor].size())

当我试图修剪模型中的线性层时,保存的模型不会显示与修剪LSTM层时相同的大小增长。知道是什么导致的吗?

EN

回答 1

Stack Overflow用户

发布于 2022-08-16 13:20:42

这是因为剪枝引入了两个新的参数weight_origweight_mask,这实际上增加了模型的大小。

如果您希望删除这些参数,您应该使用torch.nn.utils.prune.remove()删除它们并停止剪枝过程(这意味着零权值在此步骤之后不会被冻结)。

有关详细信息,请查看剪枝教程

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66987160

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档