首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何获得topk的值及其指数(2D)?

如何获得topk的值及其指数(2D)?
EN

Stack Overflow用户
提问于 2021-03-11 12:50:16
回答 1查看 187关注 0票数 0

我有两个三维张量,我想用一个人的前k个索引,得到另一个前k。

例如,对于以下张量

代码语言:javascript
复制
a = torch.tensor([[[1], [2], [3]],
                  [[4], [5], [6]]])

b = torch.tensor([[[7,1], [8,2], [9,3]],
                  [[10,4],[11,5],[12,6]]])

pytorch的topk函数将为我提供以下内容。

代码语言:javascript
复制
top_tensor, indices = torch.topk(a, 2, dim=1)

# top_tensor: tensor([[[3], [2]],
#                    [[6],  [5]]])

# indices: tensor([[[2], [1]],
#                 [[2],  [1]]])

但是我想用a的结果,映射到b。

代码语言:javascript
复制
# use indices to do something for b, get torch.tensor([[[8,2], [9,3]],
#                                                      [[11,5],[12,6]]])

在这种情况下,我不知道b的真实值,所以我不能对b使用topk。

换句话说,我想得到一个函数foo_slice,如下所示:

代码语言:javascript
复制
top_tensor, indices = torch.topk(a, 2, dim=1)
# top_tensor == foo_slice(a, indices)

有什么方法可以使用pytorch实现这一点吗?

谢谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-03-11 14:21:09

您正在寻找的解决方案是here

因此,您的问题的基于代码的解决方案如下

代码语言:javascript
复制
#inputs are changed in order from the above ques

a = torch.tensor([[[1], [2], [3]],
                  [[5], [6], [4]]])

b = torch.tensor([[[7,1], [8,2], [9,3]],
                  [[11,5],[12,6],[10,4]]])

top_tensor, indices = torch.topk(a, 2, dim=1)

v = [indices.view(-1,2)[i] for i in range(0,indices.shape[1])]


new_tensor = []
for i,f in enumerate(v):
      new_tensor.append(torch.index_select(b[i], 0, f))
print(new_tensor ) #[tensor([[9, 3],
                   #         [8, 2]]),
                   #tensor([[12,  6],
                   #        [11,  5]])]
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66576655

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档