首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >“CIFAR10DataModule”对象没有属性“train_loader”

“CIFAR10DataModule”对象没有属性“train_loader”
EN

Stack Overflow用户
提问于 2021-08-29 04:15:54
回答 1查看 916关注 0票数 1

您能告诉我为什么我没有导入CUFAR10DataModule()吗?

首先,我在GoogleColab上运行代码,

代码语言:javascript
复制
from pl_bolts.datamodules import CIFAR10DataModule
dm = CIFAR10DataModule()

然后,执行代码进行确认。

代码语言:javascript
复制
from torch.optim import Adam
optimizer = Adam(finetune_layer.parameters(), lr=1e-4)

for epoch in range(10):
  for batch in dm.train_loader:
    x, y = batch
    with torch.no_grad():
      features = backbone(x)

    preds = finetune_layer(features)
    loss = cross_entropy(preds, y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(loss.item())

但是,运行代码后返回消息AttributeError: 'CIFAR10DataModule' object has no attribute 'train_loader'

当运行代码以确认dm时,

代码语言:javascript
复制
for batch in dm.train_dataloader:
  x, y = batch
  print(x.shape, y.shape)
  break

错误说是TypeError: 'method' object is not iterable

代码与示例看起来一样,但我想知道为什么会产生这样的错误呢?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-08-29 12:54:24

代码有两个问题:

首先,获取底层PyTorch数据中心的方法是dm.train_dataloader()而不是dm.train_loader。它是一个函数,而不是一个属性。

代码语言:javascript
复制
for batch in dm.train_dataloader():
    x, y = batch
    ...

其次,由于您试图使用没有LightningDataModuleTrainer,所以需要手动调用

代码语言:javascript
复制
dm.prepare_data()
dm.setup()

。。以便通过.train_dataloader()提供数据采集器。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68969811

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档