给定一个输入张量x和一个指数idxs张量,我想检索索引不存在于idxs中的x的所有元素。也就是说,接受与torch.gather函数输出相反的输出。
使用torch.gather的示例
>>> x = torch.arange(30).reshape(3,10)
>>> idxs = torch.tensor([[1,2,3], [4,5,6], [7,8,9]], dtype=torch.long)
>>> torch.gather(x, 1, idxs)
tensor([[ 1, 2, 3],
[14, 15, 16],
[27, 28, 29]])我真正想要实现的是
tensor([[ 0, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26]])它可以是一个有效和高效的实现,可能使用torch实用程序吗?我不想用任何的循环。
我假设idxs在其最深的维度中只有唯一的元素。例如,idxs将是调用torch.topk的结果。
发布于 2021-07-24 14:08:57
您可能需要构造形状(x.size(0), x.size(1)-idxs.size(1))的张量(此处为(3, 7))。它对应于idxs的互补指数,关于x的形状,即
tensor([[0, 4, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 7, 8, 9],
[0, 1, 2, 3, 4, 5, 6]])我建议首先建立一个张量,形状类似于x,它揭示了我们想要保持的位置和我们想要丢弃的位置,一种面具。这可以使用torch.scatter来完成。这在本质上分散了0在需要的位置,即m[i, idxs[i][j]] = 0。
>>> m = torch.ones_like(x).scatter(1, idxs, 0)
tensor([[1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])然后获取非零(idxs的互补部分)。选择axis=1上的第二个指数,并根据目标张量进行整形:
>>> idxs_ = m.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
tensor([[0, 4, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 7, 8, 9],
[0, 1, 2, 3, 4, 5, 6]])现在你知道该怎么做了对吧?与您给出的torch.gather示例相同,但这次使用idxs_
>>> torch.gather(x, 1, idxs_)
tensor([[ 0, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26]])总结如下:
>>> idxs_ = torch.ones_like(x).scatter(1, idxs, 0) \
.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
>>> torch.gather(x, 1, idxs_)https://stackoverflow.com/questions/68510107
复制相似问题