首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch DataLoader随机播放

PyTorch DataLoader随机播放
EN

Stack Overflow用户
提问于 2020-04-09 14:17:04
回答 2查看 9.8K关注 0票数 4

我做了一个实验,并没有得到我所期望的结果。

对于第一部分,我使用

代码语言:javascript
复制
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=False, num_workers=0)

在训练模型之前,我将trainloader.dataset.targets保存到变量a,将trainloader.dataset.data保存到变量b。然后,我使用trainloader训练模型。

训练完成后,我将trainloader.dataset.targets保存到变量c,将trainloader.dataset.data保存到变量d。最后,我检查了a == cb == d,它们都给出了True,这是预期的,因为DataLoader的shuffle参数是False

对于第二部分,我将使用

代码语言:javascript
复制
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=True, num_workers=0)

在训练模型之前,我将trainloader.dataset.targets保存到变量e,将trainloader.dataset.data保存到变量f。然后,我使用trainloader训练模型。训练完成后,我将trainloader.dataset.targets保存到变量g,将trainloader.dataset.data保存到变量h。我预计e == gf == hshuffle=True之后都会成为False,但他们又给了True。我在DataLoader类的定义中遗漏了什么?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-04-09 15:16:13

我相信直接存储在trainloader.dataset.data或.target中的数据不会被打乱,只有当DataLoader作为生成器或迭代器被调用时,数据才会被打乱

你可以通过执行next(iter(训练加载器))来检查它,不使用洗牌和使用洗牌,它们应该会给出不同的结果

代码语言:javascript
复制
import torch
import torchvision

transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        ])
MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
                                           transform = transform)
dataLoader = torch.utils.data.DataLoader(MNIST_dataset,
                                         batch_size = 128,
                                         shuffle = False,
                                         num_workers = 10)
target = dataLoader.dataset.targets


MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
                                           transform = transform)

dataLoader_shuffled= torch.utils.data.DataLoader(MNIST_dataset,
                                         batch_size = 128,
                                         shuffle = True,
                                         num_workers = 10)

target_shuffled = dataLoader_shuffled.dataset.targets

print(target == target_shuffled)

_, target = next(iter(dataLoader));
_, target_shuffled = next(iter(dataLoader_shuffled))

print(target == target_shuffled)

这将提供:

代码语言:javascript
复制
tensor([True, True, True,  ..., True, True, True])
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False,  True, False,
        False, False, False, False, False, False, False, False])

然而,数据和目标中存储的数据和标签是一个固定的列表,由于您正尝试直接访问它,因此它们不会被打乱。

票数 3
EN

Stack Overflow用户

发布于 2020-10-27 15:09:43

我在使用Dataset类加载数据时遇到了类似的问题。我不再使用Dataset类加载数据,而是使用以下代码,它对我来说工作得很好

代码语言:javascript
复制
X = torch.from_numpy(X)
y = torch.from_numpy(y)

train_data = torch.utils.data.TensorDataset(X, y)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)

其中X和Y是csv文件中的数值数组。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61115032

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档