首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >DataLoader类误差割炬

DataLoader类误差割炬
EN

Stack Overflow用户
提问于 2019-02-27 04:38:09
回答 1查看 6.6K关注 0票数 2

我是一个初学者,并且我正在尝试使用dataloader。

实际上,我正在尝试将这个实现到我的网络中,但是加载需要很长的时间。所以,我对我的网络进行了调试,看看网络本身是否有问题,但事实证明它与我的dataloader类有关。以下是代码:

代码语言:javascript
复制
 from torch.utils.data import Dataset, DataLoader
 import numpy as np
 import pandas as pd

class DiabetesDataset(Dataset):

  def __init__(self, csv):
      self.xy = pd.read_csv(csv)

  def __len__(self):
      return len(self.xy)

  def __getitem__(self, index):
       self.x_data = torch.Tensor(xy.iloc[:, 0:-1].values)
       self.y_data = torch.Tensor(xy.iloc[:, [-1]].values)
       return self.x_data[index], self.y_data[index]

 dataset = DiabetesDataset("trial.csv")
 train_loader = DataLoader(dataset=dataset,
                      batch_size=1,
                      shuffle=True,
                      num_workers=2)`

 for a in train_loader:
    print(a)

为了验证dataloader导致了所有延迟,我创建了一个包含1s和2s列的虚拟csv文件,每个列总共有10个示例。然后,我在train_loader对象上循环,它已经超过了1小时,而且它仍然在运行,考虑到样本大小很小,并且批处理大小设置为1。

我不确定我的代码的错误是什么,它导致了这个问题。

非常感谢您的任何评论/输入!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-02-27 05:05:16

您的代码中有一些bugs -您能检查一下这个程序是否有效吗(它在我的计算机上使用您的玩具示例):

代码语言:javascript
复制
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import torch


class DiabetesDataset(Dataset):

    def __init__(self, csv):
        self.xy = pd.read_csv(csv)

    def __len__(self):
        return len(self.xy)

    def __getitem__(self, index):
        x_data = torch.Tensor(self.xy.iloc[:, 0:-1].values)
        y_data = torch.Tensor(self.xy.iloc[:, [-1]].values)
        return x_data[index], y_data[index]


dataset = DiabetesDataset("trial.csv")


train_loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
    num_workers=2)

if __name__ == '__main__':
    for a in train_loader:
        print(a)

编辑:您的代码无法工作,因为您在__getitem__方法(self.xy.iloc.)中缺少了一个self因为在脚本的末尾没有一个if __name__ == '__main__。有关第二个错误,请参见RuntimeError on windows trying python multiprocessing

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54898145

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档