首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch闪电(可训练对撞机-错)

PyTorch闪电(可训练对撞机-错)
EN

Stack Overflow用户
提问于 2022-04-18 18:22:44
回答 3查看 470关注 0票数 0

我采用多GPU训练使用电筒雷电。下面的输出显示模型:

代码语言:javascript
复制
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
┏━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃    ┃ Name       ┃ Type              ┃ Params ┃
┡━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0  │ encoder    │ Encoder           │  2.0 M │
│ 1  │ classifier │ Sequential        │  8.8 K │
│ 2  │ criterion  │ BCEWithLogitsLoss │      0 │
│ 3  │ train_acc  │ Accuracy          │      0 │
│ 4  │ val_acc    │ Accuracy          │      0 │
│ 5  │ train_auc  │ AUROC             │      0 │
│ 6  │ val_auc    │ AUROC             │      0 │
│ 7  │ train_f1   │ F1Score           │      0 │
│ 8  │ val_f1     │ F1Score           │      0 │
│ 9  │ train_mcc  │ MatthewsCorrCoef  │      0 │
│ 10 │ val_mcc    │ MatthewsCorrCoef  │      0 │
│ 11 │ train_sens │ Recall            │      0 │
│ 12 │ val_sens   │ Recall            │      0 │
│ 13 │ train_spec │ Specificity       │      0 │
│ 14 │ val_spec   │ Specificity       │      0 │
└────┴────────────┴───────────────────┴────────┘
Trainable params: 2.0 M
Non-trainable params: 0

我已将编码器设置为无法使用以下代码:

代码语言:javascript
复制
ckpt = torch.load(chk_path)
self.encoder.load_state_dict(ckpt['state_dict'])
self.encoder.requires_grad = False

trainable params不应该是8.8 K而不是2.0 M吗?

我的优化器如下:

代码语言:javascript
复制
optimizer =  torch.optim.RMSprop(filter(lambda p: p.requires_grad, self.parameters()), lr =self.lr, weight_decay = self.weight_decay)
EN

回答 3

Stack Overflow用户

发布于 2022-07-20 17:00:51

您需要为所有编码器参数逐一设置requires_grad=False

代码语言:javascript
复制
for param in self.encoder.parameters():
    param.requires_grad = False
票数 0
EN

Stack Overflow用户

发布于 2022-07-20 22:30:52

self.encoder.requires_grad = False什么也不做;事实上,torch模块没有requires_grad标志。

您应该做的是使用requires_grad_方法(注意第二个下划线),该方法将该模块的所有参数设置为所需的值:

代码语言:javascript
复制
self.encoder.requires_grad_(False)

如下所述:https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.requires_grad_

票数 0
EN

Stack Overflow用户

发布于 2022-07-29 11:51:19

请注意,如果执行以下代码:

代码语言:javascript
复制
class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

mnist_model = MNISTModel()
mnist_model.l2.requires_grad = False
print(mnist_model.l2.weight.requires_grad)
print(mnist_model.l2.bias.requires_grad)
ModelSummary(mnist_model) 

你会得到:

代码语言:javascript
复制
True
True

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 1.2 M 
1 | l2   | Linear | 2.5 M 
2 | l3   | Linear | 15.7 K
--------------------------------
3.7 M     Trainable params
0         Non-trainable params
3.7 M     Total params
14.827    Total estimated model params size (MB)

这意味着这实际上并不是对该层中的参数停用requires_grad。因此,您有两个选项,根据(https://pytorch.org/docs/stable/notes/autograd.html#setting-requires-grad)

  1. .requires_grad_()应用于@burzam (更正确的一个)

建议的模块

代码语言:javascript
复制
mnist_model = MNISTModel()
mnist_model.l2.requires_grad_(False)
ModelSummary(mnist_model)
代码语言:javascript
复制
  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 1.2 M 
1 | l2   | Linear | 2.5 M 
2 | l3   | Linear | 15.7 K
--------------------------------
1.2 M     Trainable params
2.5 M     Non-trainable params
3.7 M     Total params
14.827    Total estimated model params size (MB)

  1. 循环遍历模块

中的参数

代码语言:javascript
复制
mnist_model = MNISTModel()
for param in mnist_model.l2.parameters():
    param.requires_grad = False

ModelSummary(mnist_model)

你会看到:

代码语言:javascript
复制
  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 1.2 M 
1 | l2   | Linear | 2.5 M 
2 | l3   | Linear | 15.7 K
--------------------------------
1.2 M     Trainable params
2.5 M     Non-trainable params
3.7 M     Total params
14.827    Total estimated model params size (MB)

对于要停用的特定层中的所有参数,需要将requires_grad设置为False

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71915570

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档