首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >基于argmax的PyTorch索引

基于argmax的PyTorch索引
EN

Stack Overflow用户
提问于 2022-06-15 08:16:59
回答 1查看 683关注 0票数 0

亲爱的社区,我对PyTorch中的张量索引有一个挑战。问题很简单。给定一个张量,创建一个索引张量来索引每列的最大值。

代码语言:javascript
复制
x = T.tensor([[0, 3, 0, 5, 9, 8, 2, 0], 
              [0, 4, 9, 6, 7, 9, 1, 0]])

考虑到这个张量,我想构建一个布尔掩码,用于索引每个colum的最大值。具体来说,我不需要它的最大值torch.max(x, dim=0),也不需要它的索引torch.argmax(x, dim=0),而是基于这个张量最大值索引其他张量的布尔掩码。我的理想输出是:

代码语言:javascript
复制
# Input tensor
x
tensor([[0, 3, 0, 5, 9, 8, 2, 0],
        [0, 4, 9, 6, 7, 9, 1, 0]])

# Ideal output bool mask tensor
idx
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
        [0, 1, 1, 1, 0, 1, 0, 0]])

我知道values_max = x[idx]values_max = x.max(dim=0)是等同的,但我不是在寻找values_max,而是在寻找idx

我已经围绕它构建了一个解决方案,但是它看起来很复杂,我相信torch有一个优化的方法来完成这个任务。我试图在torch.index_select的输出中使用x.argmax(dim=0),但失败了,所以我构建了一个自定义解决方案,对我来说似乎很麻烦,所以我请求帮助,以矢量化/张力/火炬的方式来实现这一点。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-06-15 08:27:12

您可以首先使用torch.argmax提取张量的最大值列的索引,然后将keepdim设置为True,从而执行此操作。

代码语言:javascript
复制
>>> x.argmax(0, keepdim=True)
tensor([[0, 1, 1, 1, 0, 1, 0, 0]])

然后可以使用torch.scatter1s置于指定索引处的零张量中:

代码语言:javascript
复制
>>> torch.zeros_like(x).scatter(0, x.argmax(0,True), value=1)
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
        [0, 1, 1, 1, 0, 1, 0, 0]])
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72628000

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档