亲爱的社区,我对PyTorch中的张量索引有一个挑战。问题很简单。给定一个张量,创建一个索引张量来索引每列的最大值。
x = T.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
[0, 4, 9, 6, 7, 9, 1, 0]])考虑到这个张量,我想构建一个布尔掩码,用于索引每个colum的最大值。具体来说,我不需要它的最大值torch.max(x, dim=0),也不需要它的索引torch.argmax(x, dim=0),而是基于这个张量最大值索引其他张量的布尔掩码。我的理想输出是:
# Input tensor
x
tensor([[0, 3, 0, 5, 9, 8, 2, 0],
[0, 4, 9, 6, 7, 9, 1, 0]])
# Ideal output bool mask tensor
idx
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
[0, 1, 1, 1, 0, 1, 0, 0]])我知道values_max = x[idx]和values_max = x.max(dim=0)是等同的,但我不是在寻找values_max,而是在寻找idx。
我已经围绕它构建了一个解决方案,但是它看起来很复杂,我相信torch有一个优化的方法来完成这个任务。我试图在torch.index_select的输出中使用x.argmax(dim=0),但失败了,所以我构建了一个自定义解决方案,对我来说似乎很麻烦,所以我请求帮助,以矢量化/张力/火炬的方式来实现这一点。
发布于 2022-06-15 08:27:12
您可以首先使用torch.argmax提取张量的最大值列的索引,然后将keepdim设置为True,从而执行此操作。
>>> x.argmax(0, keepdim=True)
tensor([[0, 1, 1, 1, 0, 1, 0, 0]])然后可以使用torch.scatter将1s置于指定索引处的零张量中:
>>> torch.zeros_like(x).scatter(0, x.argmax(0,True), value=1)
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
[0, 1, 1, 1, 0, 1, 0, 0]])https://stackoverflow.com/questions/72628000
复制相似问题