首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >是否可以使用PyTorch数据加载器加载保存在CSV文件中的原始数据图像?

是否可以使用PyTorch数据加载器加载保存在CSV文件中的原始数据图像?
EN

Stack Overflow用户
提问于 2021-08-07 12:59:03
回答 1查看 340关注 0票数 0

我将原始数据图像保存在单独的CSV文件中(每个图像在一个文件中)。我想用PyTorch训练一个美国有线电视新闻网。我应该如何加载数据以适合用作CNN的输入?(另外,它是1个通道,图像网络的输入是默认的RGB )

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-08-07 14:23:17

顾名思义,PyTorch的DataLoader只是一个工具类,它可以帮助您并行加载数据、构建批处理、混洗等等,而您需要的是一个自定义的Dataset实现。

忽略图像存储在CSV文件中有点奇怪的事实,你只需要这样的东西:

代码语言:javascript
复制
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):

    def __init__(self, path: Path, ...):
        # do some preliminary checks, e.g. your path exists, files are there...
        assert path.exists()
        ...
        # retrieve your files in some way, e.g. glob
        self.csv_files = list(glob.glob(str(path / "*.csv")))

    def __len__(self) -> int:
        # this lets you know len(dataset) once you instantiate it
        return len(self.csv_files)


    def __getitem__(self, index: int) -> Any:
        # this method is called by the dataloader, each index refers to
        # a CSV file in the list you built in the constructor
        csv = self.csv_files[index]
        # now do whatever you need to do and return some tensors
        image, label = self.load_image(csv)
        return image, label

就是这样,或多或少。然后,您可以创建数据集,将其传递给dataloader并迭代dataloader,如下所示:

代码语言:javascript
复制
dataset = CustomDataset(Path("path/to/csv/files"))
train_loader = DataLoader(dataset, shuffle=True, num_workers=8,...)

for batch in train_loader:
    ...
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68692578

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档