您能告诉我为什么我没有导入CUFAR10DataModule()吗?
首先,我在GoogleColab上运行代码,
from pl_bolts.datamodules import CIFAR10DataModule
dm = CIFAR10DataModule()然后,执行代码进行确认。
from torch.optim import Adam
optimizer = Adam(finetune_layer.parameters(), lr=1e-4)
for epoch in range(10):
for batch in dm.train_loader:
x, y = batch
with torch.no_grad():
features = backbone(x)
preds = finetune_layer(features)
loss = cross_entropy(preds, y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(loss.item())但是,运行代码后返回消息AttributeError: 'CIFAR10DataModule' object has no attribute 'train_loader'。
当运行代码以确认dm时,
for batch in dm.train_dataloader:
x, y = batch
print(x.shape, y.shape)
break错误说是TypeError: 'method' object is not iterable。
代码与示例看起来一样,但我想知道为什么会产生这样的错误呢?
发布于 2021-08-29 12:54:24
代码有两个问题:
首先,获取底层PyTorch数据中心的方法是dm.train_dataloader()而不是dm.train_loader。它是一个函数,而不是一个属性。
for batch in dm.train_dataloader():
x, y = batch
...其次,由于您试图使用没有LightningDataModule的Trainer,所以需要手动调用
dm.prepare_data()
dm.setup()。。以便通过.train_dataloader()提供数据采集器。
https://stackoverflow.com/questions/68969811
复制相似问题