I want to save a PyTorch torch.utils.data.dataloader.DataLoader instance so that I can resume training where I left off (preserving the random seed, the state, and everything else).
Posted on 2020-04-04 16:16:12
This is straightforward: design your own Sampler that takes a starting index and handles shuffling the data itself:
import random

from torch.utils.data import Sampler  # Sampler lives in torch.utils.data

random.seed(224)  # use a fixed seed so the shuffle order is reproducible across runs

class MySampler(Sampler):
    def __init__(self, data, i=0):
        # batch_size is assumed to be defined in the enclosing scope
        random.shuffle(data)  # shuffle the underlying data in place
        self.seq = list(range(len(data)))[i * batch_size:]  # skip the i batches already consumed

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)

Now save the last index i somewhere, and use it the next time you instantiate the DataLoader:
train_dataset = MyDataset(train_data)
train_sampler = MySampler(train_dataset, last_i)
train_data_loader = DataLoader(dataset=train_dataset,
                               batch_size=batch_size,
                               sampler=train_sampler,
                               shuffle=False)  # shuffle must be False when a custom sampler is supplied

This is very handy when training on Colab.
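To see the resume arithmetic from MySampler in isolation, here is a small pure-Python sketch (no torch required; the dataset size, batch size, and `last_i` value are illustrative, not from the answer):

```python
import random

batch_size = 4
data = list(range(10))  # stand-in for a dataset of 10 samples

random.seed(224)
random.shuffle(data)  # shuffles the data in place, as MySampler.__init__ does

# Full index sequence, as MySampler builds it with i=0:
full_seq = list(range(len(data)))

# Suppose training stopped after consuming 2 batches; resuming with i=2
# skips the first 2 * batch_size positions of the index sequence:
last_i = 2
resumed_seq = list(range(len(data)))[last_i * batch_size:]

assert resumed_seq == full_seq[last_i * batch_size:]
print(resumed_seq)  # the remaining index positions: [8, 9]
```

Note that because the seed is fixed, both the original run and the resumed run shuffle `data` into the same order, so skipping index positions really does continue where the first run stopped.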
Posted on 2020-08-25 18:53:14
You need a custom Sampler implementation. You can use the one from here directly: https://gist.github.com/usamec/1b3b4dcbafad2d58faa71a9633eea6a5
You can save and restore it like this:
sampler = ResumableRandomSampler(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler, pin_memory=True)

for x in loader:
    print(x)
    break

sampler2 = ResumableRandomSampler(dataset)
torch.save(sampler.get_state(), "test_samp.pth")
sampler2.set_state(torch.load("test_samp.pth"))
loader2 = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler2, pin_memory=True)

for x in loader2:
    print(x)

Posted on 2021-07-20 22:11:03
Native PyTorch support for this is still unavailable, though it is being considered for future improvements. In the meantime, see the other answers here for custom-built solutions.
https://stackoverflow.com/questions/60993677