首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >pytorch dataloader:沿着dataloader输出的一个维度连接批处理

pytorch dataloader:沿着dataloader输出的一个维度连接批处理
EN

Stack Overflow用户
提问于 2021-06-23 00:24:31
回答 1查看 257关注 0票数 0

我的数据集的__getitem__函数返回一个torch.stft() M x N x D张量,其中N是长度可变的音频输入序列。每一项都在__getitem__函数中读取。我想要沿着第二个维度(N)连接批次。因此,通过迭代数据加载器,我将得到如下形式的数据:M x (N X batch_size) x D。这个问题有可能的解决方案吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-06-23 00:44:57

您可以使用传递给DataLoader的自定义collate函数来完成此操作:

代码语言:javascript
复制
import torch
from torch.utils.data import DataLoader

M = 20
D = 12
N = 30
a = torch.rand((M,N,D))
b = torch.rand((M,N,D))

def my_collate(batch):
    c = torch.stack(batch, dim=1)
    return c.permute(0, 2, 1, 3)

c = my_collate([a,b]) # output shape  MxNxBxD-> torch.Size([20, 30, 2, 12])

然后传递给DataLoader:

代码语言:javascript
复制
loader = DataLoader(dataset=datasetObject, batch_size=1, collate_fn=my_collate)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68087353

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档