首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >PyTorch DataLoader

PyTorch DataLoader
EN

Stack Overflow用户
提问于 2017-12-21 22:31:54
回答 1查看 2.8K关注 0票数 2

我试图使用多个torch.utils.data.DataLoader来创建具有不同转换的数据集。目前,我的代码大致是

代码语言:javascript
复制
d_transforms = [
    transforms.RandomHorizontalFlip(),
    # Some other transforms...
]
loaders = []
for i in range(len(d_transforms)):
    dataset = datasets.MNIST('./data', 
            train=train, 
            download=True, 
            transform=d_transforms[i]
    loaders.append(
        DataLoader(dataset, 
            shuffle=True, 
            pin_memory=True, 
            num_workers=1)
        )

这很管用,但速度非常慢。核仁显示,在我的代码中几乎所有的时间都花在了类似的行上

代码语言:javascript
复制
x, y = next(iter(train_loaders[i]))

我怀疑这是因为我使用了多个DataLoader实例,每个实例都有自己的工作人员,这些实例试图读取相同的数据文件。

我的问题是,有什么更好的方法来做到这一点?理想情况下,我将子类torch.utils.data.DataSet,并指定我希望在采样时应用的转换,但由于__getitem__无法获取参数,这似乎是不可能的。

EN

回答 1

Stack Overflow用户

发布于 2017-12-26 22:55:17

__getitem__确实采用了一个参数,它是要加载的内容的索引。就像。

代码语言:javascript
复制
transform = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

class CountDataset(Dataset):

def __init__(self, file,transform=None):

    self.transform = transform
    #self.vocab = vocab
    with open(file,'rb') as f:
        self.data = pickle.load(f)
    self.y = self.data['answers']
    self.I = self.data['images']


def __len__(self):
    return len(self.y)

def __getitem__(self, idx):
    img_name = self.I[idx]
    label = self.y[Idx]
    fname = '/'.join(img_name.split("/")[-2:]) #/train2014/xx.jpg
    DIR = '/hdd/manoj/VQA/Images/mscoco/'
    img_full_path = os.path.join(DIR,fname)
    img = Image.open(img_full_path).convert("RGB")
    img_tensor = self.transform(img.resize((224,224)))
    return img_tensor,label


testset = CountDataset(file = 'testdat.pkl',
                        transform = transform)


testloader = DataLoader(testset, batch_size=32,
                         shuffle=False, num_workers=4)

您不需要在循环中调用数据加载器。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47933597

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档