我用虹膜数据集来训练一个简单的网络。
trainset = iris.Iris(train=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=150,
shuffle=True, num_workers=2)
dataiter = iter(trainloader)dataset本身只有150个数据点,py手电数据管理器在整个数据集上迭代jus一次,因为批处理大小为150。
我现在的问题是,如果一旦完成迭代,是否有任何方法可以告诉py手电筒的数据管理员在dataset上重复?
萨纳克斯
更新
)刚刚创建了一个dataloader的子类并实现了我自己的__next__()
发布于 2019-11-15 12:15:59
来补充之前的答案。为了在数据集之间进行比较,通常最好使用步骤总数而不是作为超参数的历元总数。这是因为迭代次数不应与数据集大小有关,而应与其复杂性有关。
我正在使用以下代码进行培训。它确保数据加载器每次重新启动数据时都会重新调整数据。
# main training loop
generator = iter(trainloader)
for i in range(max_steps):
try:
# Samples the batch
x, y = next(generator)
except StopIteration:
# restart the generator if the previous generator is exhausted.
generator = iter(trainloader)
x, y = next(generator)我会同意,这不是最优雅的解决方案,但它使我不必依赖时代的训练。
发布于 2018-04-23 18:23:16
使用itertools.cycle有一个重要的缺点,因为它不会在每次迭代之后对数据进行洗牌:
当可迭代耗尽时,从保存的副本中返回元素。
在某些情况下,这会对模型的性能产生负面影响。解决这一问题的方法可以是编写您自己的循环生成器:
def cycle(iterable):
while True:
for x in iterable:
yield x你会把它用作:
dataiter = iter(cycle(trainloader))发布于 2017-12-08 14:20:49
最简单的选择是只使用嵌套循环:
for i in range(10):
for batch in trainloader:
do_something(batch)另一种选择是使用itertools.cycle,也许与itertools.take结合使用。
当然,使用批处理大小等于整个数据集的DataLoader有点不寻常。您也不需要在列车加载程序上调用iter()。
https://stackoverflow.com/questions/47714643
复制相似问题