首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Pytorch DataLoader迭代顺序稳定吗?

Pytorch DataLoader迭代顺序稳定吗?
EN

Stack Overflow用户
提问于 2019-12-13 07:35:00
回答 2查看 4.7K关注 0票数 1

Pytorch Dataloader的迭代顺序是否保证相同(在温和条件下)?

例如:

代码语言:javascript
复制
dataloader = DataLoader(my_dataset, batch_size=4,
                        shuffle=True, num_workers=4)
print("run 1")
for batch in dataloader:
  print(batch["index"])

print("run 2")
for batch in dataloader:
  print(batch["index"])

到目前为止,我已经测试过了,它似乎不是固定的,两次运行的顺序都是一样的。有没有办法让顺序一致呢?谢谢

编辑:我也尝试过这样做

代码语言:javascript
复制
unlabeled_sampler = data.sampler.SubsetRandomSampler(unlabeled_indices)
unlabeled_dataloader = data.DataLoader(train_dataset, 
                sampler=unlabeled_sampler, batch_size=args.batch_size, drop_last=False)

然后通过数据加载器迭代两次,但同样的非确定性结果。

EN

回答 2

Stack Overflow用户

发布于 2019-12-14 03:38:15

简短的答案是否定的,当shuffle=True时,DataLoader的迭代顺序在迭代之间是不稳定的。每次在加载器上迭代时,内部RandomSampler都会创建一个新的随机顺序。

获得稳定的随机DataLoader的一种方法是使用一组随机索引创建一个Subset数据集。

代码语言:javascript
复制
shuffled_dataset = torch.utils.data.Subset(my_dataset, torch.randperm(len(my_dataset)).tolist())
dataloader = DataLoader(shuffled_dataset, batch_size=4, num_workers=4, shuffled=False)
票数 5
EN

Stack Overflow用户

发布于 2019-12-16 08:18:31

我实际上选择了jodag在评论中的回答:

代码语言:javascript
复制
torch.manual_seed("0")

for i,elt in enumerate(unlabeled_dataloader):
    order.append(elt[2].item())
    print(elt)

    if i > 10:
        break

torch.manual_seed("0")

print("new dataloader")
for i,elt in enumerate( unlabeled_dataloader):
    print(elt)
    if i > 10:
        break
exit(1)                       

和输出:

代码语言:javascript
复制
[tensor([[-0.3583, -0.6944]]), tensor([3]), tensor([1610])]
[tensor([[-0.6623, -0.3790]]), tensor([3]), tensor([1958])]
[tensor([[-0.5046, -0.6399]]), tensor([3]), tensor([1814])]
[tensor([[-0.5349,  0.2365]]), tensor([2]), tensor([1086])]
[tensor([[-0.1310,  0.1158]]), tensor([0]), tensor([321])]
[tensor([[-0.2085,  0.0727]]), tensor([0]), tensor([422])]
[tensor([[ 0.1263, -0.1597]]), tensor([0]), tensor([142])]
[tensor([[-0.1387,  0.3769]]), tensor([1]), tensor([894])]
[tensor([[-0.0500,  0.8009]]), tensor([3]), tensor([1924])]
[tensor([[-0.6907,  0.6448]]), tensor([4]), tensor([2016])]
[tensor([[-0.2817,  0.5136]]), tensor([2]), tensor([1267])]
[tensor([[-0.4257,  0.8338]]), tensor([4]), tensor([2411])]
new dataloader
[tensor([[-0.3583, -0.6944]]), tensor([3]), tensor([1610])]
[tensor([[-0.6623, -0.3790]]), tensor([3]), tensor([1958])]
[tensor([[-0.5046, -0.6399]]), tensor([3]), tensor([1814])]
[tensor([[-0.5349,  0.2365]]), tensor([2]), tensor([1086])]
[tensor([[-0.1310,  0.1158]]), tensor([0]), tensor([321])]
[tensor([[-0.2085,  0.0727]]), tensor([0]), tensor([422])]
[tensor([[ 0.1263, -0.1597]]), tensor([0]), tensor([142])]
[tensor([[-0.1387,  0.3769]]), tensor([1]), tensor([894])]
[tensor([[-0.0500,  0.8009]]), tensor([3]), tensor([1924])]
[tensor([[-0.6907,  0.6448]]), tensor([4]), tensor([2016])]
[tensor([[-0.2817,  0.5136]]), tensor([2]), tensor([1267])]
[tensor([[-0.4257,  0.8338]]), tensor([4]), tensor([2411])]

这就是我们想要的。然而,我认为jodag的主要答案仍然更好;这只是一个目前有效的快速技巧;)

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59314174

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档