首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >pytorch数据加载器多次迭代

pytorch数据加载器多次迭代
EN

Stack Overflow用户
提问于 2017-12-08 12:41:12
回答 4查看 29.6K关注 0票数 9

我用虹膜数据集来训练一个简单的网络。

代码语言:javascript
复制
trainset = iris.Iris(train=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=150,
                                          shuffle=True, num_workers=2)

dataiter = iter(trainloader)

dataset本身只有150个数据点,py手电数据管理器在整个数据集上迭代jus一次,因为批处理大小为150。

我现在的问题是,如果一旦完成迭代,是否有任何方法可以告诉py手电筒的数据管理员在dataset上重复?

萨纳克斯

更新

)刚刚创建了一个dataloader的子类并实现了我自己的__next__()

EN

回答 4

Stack Overflow用户

发布于 2019-11-15 12:15:59

来补充之前的答案。为了在数据集之间进行比较,通常最好使用步骤总数而不是作为超参数的历元总数。这是因为迭代次数不应与数据集大小有关,而应与其复杂性有关。

我正在使用以下代码进行培训。它确保数据加载器每次重新启动数据时都会重新调整数据。

代码语言:javascript
复制
# main training loop
    generator = iter(trainloader)
    for i in range(max_steps):

        try:
            # Samples the batch
            x, y = next(generator)
        except StopIteration:
            # restart the generator if the previous generator is exhausted.
            generator = iter(trainloader)
            x, y = next(generator)

我会同意,这不是最优雅的解决方案,但它使我不必依赖时代的训练。

票数 8
EN

Stack Overflow用户

发布于 2018-04-23 18:23:16

使用itertools.cycle有一个重要的缺点,因为它不会在每次迭代之后对数据进行洗牌:

当可迭代耗尽时,从保存的副本中返回元素。

在某些情况下,这会对模型的性能产生负面影响。解决这一问题的方法可以是编写您自己的循环生成器:

代码语言:javascript
复制
def cycle(iterable):
    while True:
        for x in iterable:
            yield x

你会把它用作:

代码语言:javascript
复制
dataiter = iter(cycle(trainloader))
票数 6
EN

Stack Overflow用户

发布于 2017-12-08 14:20:49

最简单的选择是只使用嵌套循环:

代码语言:javascript
复制
for i in range(10):
    for batch in trainloader:
        do_something(batch)

另一种选择是使用itertools.cycle,也许与itertools.take结合使用。

当然,使用批处理大小等于整个数据集的DataLoader有点不寻常。您也不需要在列车加载程序上调用iter()。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47714643

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档