首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何保存PyTorch的DataLoader实例?

如何保存PyTorch的DataLoader实例?
EN

Stack Overflow用户
提问于 2020-04-02 22:14:48
回答 3查看 3.2K关注 0票数 3

我想保存PyTorch的torch.utils.data.dataloader.DataLoader实例,这样我就可以继续训练我离开的地方(保留随机种子、状态和所有东西)。

EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2020-04-04 16:16:12

这很简单。人们应该设计自己的Sampler,它获取起始索引,并自动对数据进行混洗:

代码语言:javascript
复制
import random
from torch.utils.data.dataloader import Sampler


random.seed(224)  # use a fixed number


class MySampler(Sampler):
    def __init__(self, data, i=0):
        random.shuffle(data)
        self.seq = list(range(len(data)))[i * batch_size:]

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)

现在将最后一个索引i保存在某个地方,下次使用它实例化DataLoader

代码语言:javascript
复制
train_dataset = MyDataset(train_data)
train_sampler = MySampler(train_dataset, last_i)
train_data_loader = DataLoader(dataset=train_dataset,                                                         
                               batch_size=batch_size, 
                               sampler=train_sampler,
                               shuffle=False)  # don't forget to set DataLoader's shuffle to False

在Colab上进行训练时,它非常有用。

票数 2
EN

Stack Overflow用户

发布于 2020-08-25 18:53:14

您需要一个采样器的自定义实现。可以从以下位置轻松使用:https://gist.github.com/usamec/1b3b4dcbafad2d58faa71a9633eea6a5

您可以保存和恢复,如下所示:

代码语言:javascript
复制
sampler = ResumableRandomSampler(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler, pin_memory=True)

for x in loader:
    print(x)
    break

sampler2 = ResumableRandomSampler(dataset)
torch.save(sampler.get_state(), "test_samp.pth")
sampler2.set_state(torch.load("test_samp.pth"))
loader2 = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler2, pin_memory=True)

for x in loader2:
    print(x)
票数 3
EN

Stack Overflow用户

发布于 2021-07-20 22:11:03

but considered for future improvements,对此的原生PyTorch支持仍然不可用。不过,请参阅有关自定义构建的其他答案。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60993677

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档