我是应用剪枝使用的torch.nn.utils.prune在一个模型与低LSTM层。但是,当我保存state_dict的内容时,模型要比剪枝前大得多。我不知道为什么,就像我在剪枝前后打印出state_dict元素的大小一样,所有的内容都是相同的维度,并且在state_dict中没有其他元素。
我的剪枝代码是非常标准的,我一定要调用prune.remove()
model_state = model.state_dict()
torch.save(model.state_dict(), 'pre_pruning.pth')
for param_tensor in model_state:
print(param_tensor, "\t", model_state[param_tensor].size())
parameters_to_prune = []
for param, _ in model.rnn.named_parameters():
if "weight" in param:
parameters_to_prune.append((model.rnn, param))
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.6)
for module, param in parameters_to_prune:
prune.remove(module, param)
model_state = model.state_dict()
torch.save(model_state, 'pruned.pth') # This file is much larger than the original
for param_tensor in model_state:
print(param_tensor, "\t", model_state[param_tensor].size())当我试图修剪模型中的线性层时,保存的模型不会显示与修剪LSTM层时相同的大小增长。知道是什么导致的吗?
发布于 2022-08-16 13:20:42
这是因为剪枝引入了两个新的参数weight_orig和weight_mask,这实际上增加了模型的大小。
如果您希望删除这些参数,您应该使用torch.nn.utils.prune.remove()删除它们并停止剪枝过程(这意味着零权值在此步骤之后不会被冻结)。
有关详细信息,请查看剪枝教程。
https://stackoverflow.com/questions/66987160
复制相似问题