首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >大型csv文件用火炬数据采集器.增量加载

大型csv文件用火炬数据采集器.增量加载
EN

Stack Overflow用户
提问于 2022-01-01 19:52:11
回答 1查看 1.9K关注 0票数 1

我正在尝试编写一个定制的torch数据加载器,这样就可以增量地加载大型CSV文件(通过块加载)。

我对如何做到这一点有一个粗略的想法。但是,我一直收到一些PyTorch错误,我不知道如何解决。

代码语言:javascript
复制
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

# Create dummy csv data
nb_samples = 110
a = np.arange(nb_samples)
df = pd.DataFrame(a, columns=['data'])
df.to_csv('data.csv', index=False)


# Create Dataset
class CSVDataset(Dataset):
    def __init__(self, path, chunksize, nb_samples):
        self.path = path
        self.chunksize = chunksize
        self.len = nb_samples / self.chunksize

    def __getitem__(self, index):
        x = next(
            pd.read_csv(
                self.path,
                skiprows=index * self.chunksize + 1,  #+1, since we skip the header
                chunksize=self.chunksize,
                names=['data']))
        x = torch.from_numpy(x.data.values)
        return x

    def __len__(self):
        return self.len


dataset = CSVDataset('data.csv', chunksize=10, nb_samples=nb_samples)
loader = DataLoader(dataset, batch_size=10, num_workers=1, shuffle=False)

for batch_idx, data in enumerate(loader):
    print('batch: {}\tdata: {}'.format(batch_idx, data))

我得到了'float' object cannot be interpreted as an integer错误

EN

回答 1

Stack Overflow用户

发布于 2022-01-02 09:45:40

此错误是由以下一行引起的:

代码语言:javascript
复制
self.len = nb_samples / self.chunksize

当使用/除法时,结果总是一个浮点数。但是您只能在__len__()函数中返回一个整数。因此,您必须舍入self.len和/或将其转换为整数。例如,只需这样做:

代码语言:javascript
复制
self.len = nb_samples // self.chunksize

双斜杠(//)向下舍入并转换为整数。

编辑:您可以在__len__()中返回一个浮点数,但是当调用len(dataset)时会出现错误。所以我猜len(dataset)是在DataLoader类的某个地方被调用的。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70551454

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档