首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么火把闪电configure_optimizer扔AssertionError: param组一定是一个小块?

为什么火把闪电configure_optimizer扔AssertionError: param组一定是一个小块?
EN

Stack Overflow用户
提问于 2022-02-15 16:29:33
回答 2查看 256关注 0票数 0

我已经建立了多个火把闪电项目在过去和建立一个新的快速演示项目,我偶然发现这个奇怪的错误,不知怎么我无法摆脱它。

这是我的模型文件的相关部分。

代码语言:javascript
复制
class TSModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 10, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )

        self.classifier = nn.Sequential(
            nn.Linear(10*16*16, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        N = x.shape[0]
        x = self.backbone(x)
        x = x.view(N, -1)
        return self.classifier(x)

    def configure_optimizers(self):
        params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(self.parameters())

但是,在启动培训过程时,程序将退出,并引发以下内容:

代码语言:javascript
复制
Traceback (most recent call last):
  File "/torchserve-example/main.py", line 25, in <module>
    ts_train()
  File "/torchserve-example/main.py", line 21, in ts_train
    trainer.fit(model, datamodule)
  File ".local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
    self._run(model)
  File ".local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 715, in _run
    self.accelerator.setup(self, model)  # note: this sets up self.lightning_module
  File ".local/lib/python3.8/site-packages/pytorch_lightning/accelerators/cpu.py", line 39, in setup
    return super().setup(trainer, model)
  File ".local/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in setup
    self.setup_optimizers(trainer)
  File ".local/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 374, in setup_optimizers
    optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
  File ".local/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 190, in init_optimizers
    return trainer.init_optimizers(model)
  File ".local/lib/python3.8/site-packages/pytorch_lightning/trainer/optimizers.py", line 34, in init_optimizers
    optim_conf = model.configure_optimizers()
  File "/torchserve-example/model.py", line 52, in configure_optimizers
    return torch.optim.AdamW(self.parameters())
  File ".local/lib/python3.8/site-packages/torch/optim/adamw.py", line 47, in __init__
    super(AdamW, self).__init__(params, defaults)
  File ".local/lib/python3.8/site-packages/torch/optim/optimizer.py", line 55, in __init__
    self.add_param_group(param_group)
  File ".local/lib/python3.8/site-packages/torch/optim/optimizer.py", line 242, in add_param_group
    assert isinstance(param_group, dict), "param group must be a dict"
AssertionError: param group must be a dict

当我在print(type(params[0]))中执行configure_optimizers时,它会将<class 'torch.nn.parameter.Parameter'>打印到stdout。知道这里出了什么问题吗?

注意:由于这个错误发生在优化器的初始化过程中,这很可能与火把闪电没有直接关系,这也是为什么我也将py手电作为一个标签。

EN

回答 2

Stack Overflow用户

发布于 2022-02-15 16:43:32

在图书馆代码中,我发现:

代码语言:javascript
复制
# if not isinstance(param_groups[0], dict):
#             param_groups = [{'params': param_groups}]

在注释这一点时,一切正常工作。我没有回答这个问题,因为更改基础库或将此代码部分复制到我的文件中并不是一个很好的解决方案。

票数 0
EN

Stack Overflow用户

发布于 2022-02-15 19:13:52

实际上,这一行代码是错误的:

代码语言:javascript
复制
def configure_optimizers(self):
        params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(self.parameters())

您传递的是params,而不是self.parameters(),因为这样做很好。

使用像这样创建的params,本质上是传递带有生成器的列表,不是dict的实例。

在PyTorch中,可以通过列表中包含的dicts传递具有不同学习速率等的多个不同参数。在PyTorch的API中,这就是您的params看起来的样子。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71130015

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档