我的数据集的__getitem__函数返回一个torch.stft() M x N x D张量,其中N是长度可变的音频输入序列。每一项都在__getitem__函数中读取。我想要沿着第二个维度(N)连接批次。因此,通过迭代数据加载器,我将得到如下形式的数据:M x (N X batch_size) x D。这个问题有可能的解决方案吗?
发布于 2021-06-23 00:44:57
您可以使用传递给DataLoader的自定义collate函数来完成此操作:
import torch
from torch.utils.data import DataLoader
M = 20
D = 12
N = 30
a = torch.rand((M,N,D))
b = torch.rand((M,N,D))
def my_collate(batch):
c = torch.stack(batch, dim=1)
return c.permute(0, 2, 1, 3)
c = my_collate([a,b]) # output shape MxNxBxD-> torch.Size([20, 30, 2, 12])然后传递给DataLoader:
loader = DataLoader(dataset=datasetObject, batch_size=1, collate_fn=my_collate)https://stackoverflow.com/questions/68087353
复制相似问题