首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >ORL数据集Pytorch数据集输入数据

ORL数据集Pytorch数据集输入数据
EN

Stack Overflow用户
提问于 2022-10-30 13:19:44
回答 1查看 67关注 0票数 0

我试图在Py手电中建立一个神经网络来识别著名的Olivetti人脸数据集(ORL数据集)中的人脸。图像的维数为32x32=1024,其中40040类。我将数据集从.mat文件传输到Python熟悉的变量环境中。

代码语言:javascript
复制
orl = loadmat('ORL_32x32.mat')
x = orl["fea"]
y = orl["gnd"]
df = pd.DataFrame(x)
df_label = pd.DataFrame(y) 
df.to_csv("data.csv", index = False)
df_label.to_csv("y.csv", index = False)

在此之后,我执行了以下代码

代码语言:javascript
复制
label = torchvision.transforms.functional.to_tensor(df_label.values) #shape torch.Size([1, 400, 1])
df_tensor = torchvision.transforms.functional.to_tensor(df.values)  #shape torch.Size([1, 400, 1024])

之后,我创建了一个张量数据集,并开始通过历代进行训练。

代码语言:javascript
复制
trn = TensorDataset(df_tensor,label)
#print(type(trn))
trn_dataloader = torch.utils.data.DataLoader(trn,batch_size=400,shuffle=False, num_workers=4)
for epoch in range(EPOCHS):

  for batch_idx, (data, target) in enumerate(trn_dataloader):   
         print(data.shape)   #torch.Size([1, 400, 1024])

这实际上是一个大问题-因为data.shape应该是torch.Size(1,1,1024),只是一个图像,而不是整个数据集看起来像一个图像。

解决整个问题的最好方法是什么?

EN

回答 1

Stack Overflow用户

发布于 2022-10-30 19:12:13

您已经指定数据中心的批处理大小为400,您说这是数据集中的图像数量。因此,dataloader循环中的data张量将包含所有图像。如果将批处理大小设置为1,则会看到数据将具有形状(1, 1, 1024)

根据您培训模型的方式,您将相应地调整批处理大小,但通常不会使用1作为批处理大小进行培训。

由于使用了PyTorch,我建议将数据重组为图像的标准方式,即(batch size, number of channels, height, width)。看起来您正在处理扁平的图像,因此形状应该是(batch size, number of features)

在我看来,您的data.csv似乎有一些错误的安排,以正确的方式加载。加载时,它会混合通道大小和批处理大小。但这可以通过改变张量来解决:

代码语言:javascript
复制
df_tensor = df_tensor.permute(1, 0, 2) # Shape: (1, 400, 1024) -> (400, 1, 1024)

或者取消通道维度,因为这些图像是扁平的:

代码语言:javascript
复制
df_tensor = df_tensor.squeeze(0) # Shape: (1, 400, 1024) -> (400, 1024)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74253403

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档