首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >维度上每个张量的PyTorch张量topk

维度上每个张量的PyTorch张量topk
EN

Stack Overflow用户
提问于 2021-04-01 22:34:41
回答 2查看 95关注 0票数 0

我有以下张量

代码语言:javascript
复制
inp = tensor([[[ 0.0000e+00,  5.7100e+02, -6.9846e+00],
     [ 0.0000e+00,  4.4070e+03, -7.1008e+00],
     [ 0.0000e+00,  3.0300e+02, -7.2226e+00],
     [ 0.0000e+00,  6.8000e+01, -7.2777e+00],
     [ 1.0000e+00,  5.7100e+02, -6.9846e+00],
     [ 1.0000e+00,  4.4070e+03, -7.1008e+00],
     [ 1.0000e+00,  3.0300e+02, -7.2226e+00],
     [ 1.0000e+00,  6.8000e+01, -7.2777e+00]],

    [[ 0.0000e+00,  2.1610e+03, -7.0754e+00],
     [ 0.0000e+00,  6.8000e+01, -7.2259e+00],
     [ 0.0000e+00,  1.0620e+03, -7.2920e+00],
     [ 0.0000e+00,  2.9330e+03, -7.3009e+00],
     [ 1.0000e+00,  2.1610e+03, -7.0754e+00],
     [ 1.0000e+00,  6.8000e+01, -7.2259e+00],
     [ 1.0000e+00,  1.0620e+03, -7.2920e+00],
     [ 1.0000e+00,  2.9330e+03, -7.3009e+00]],

    [[ 0.0000e+00,  4.4070e+03, -7.1947e+00],
     [ 0.0000e+00,  3.5600e+02, -7.2958e+00],
     [ 0.0000e+00,  3.0300e+02, -7.3232e+00],
     [ 0.0000e+00,  1.2910e+03, -7.3615e+00],
     [ 1.0000e+00,  4.4070e+03, -7.1947e+00],
     [ 1.0000e+00,  3.5600e+02, -7.2958e+00],
     [ 1.0000e+00,  3.0300e+02, -7.3232e+00],
     [ 1.0000e+00,  1.2910e+03, -7.3615e+00]]])

形状的

代码语言:javascript
复制
torch.Size([3, 8, 3])

我希望在dim1中找到topk(k=4)元素,其中排序的值是dim2 (负值)。由此得到的张量形状应该是:

代码语言:javascript
复制
torch.Size([3, 4, 3])

我知道如何对单个张量执行topk,但如何一次对多个批次执行此操作?

EN

回答 2

Stack Overflow用户

发布于 2021-04-01 23:45:27

我是这样做的:

代码语言:javascript
复制
val, ind = inp[:, :, 2].squeeze().topk(k=4, dim=1, sorted=True)
new_ind = ind.unsqueeze(-1).repeat(1,1,3)
result = inp.gather(1, new_ind)

我不知道这是不是最好的方法,但它是有效的。

票数 1
EN

Stack Overflow用户

发布于 2021-04-01 23:51:00

实现此目的的一种方法是将fancy indexingbroadcasting组合在一起,如下所示:

我以形状为(3, 4, 3)k为2的随机张量x为例。

代码语言:javascript
复制
>>> import torch
>>> x = torch.rand(3, 4, 3)
>>> x
tensor([[[0.0256, 0.7366, 0.2528],
         [0.5596, 0.9450, 0.5795],
         [0.8265, 0.5469, 0.8304],
         [0.4223, 0.5206, 0.2898]],

        [[0.2159, 0.0369, 0.6869],
         [0.4556, 0.5804, 0.3169],
         [0.8194, 0.5240, 0.0055],
         [0.8357, 0.4162, 0.3740]],

        [[0.3849, 0.0223, 0.9951],
         [0.2872, 0.5952, 0.6570],
         [0.1433, 0.8450, 0.6557],
         [0.0270, 0.9176, 0.3904]]])

现在沿着所需的维度(这里是最后一个)对张量进行排序,并获得索引:

代码语言:javascript
复制
>>> _, idx = torch.sort(x[:, :, -1])
>>> k = 2
>>> idx = idx[:, :k]
# idx is = 
tensor([[0, 3],
        [2, 1],
        [3, 2]])

现在生成三对索引(i, j, k)来对原始张量进行切片,如下所示:

代码语言:javascript
复制
>>> i = torch.arange(x.shape[0]).reshape(x.shape[0], 1, 1)
>>> j = idx.reshape(x.shape[0], -1, 1)
>>> k = torch.arange(x.shape[2]).reshape(1, 1, x.shape[2])

请注意,一旦您通过(i, j, k)对任何内容进行索引,它们将转到expand并采用形状(x.shape[0], k, x.shape[2]),这是此处所需的输出形状。现在只需通过i,j和k对x进行索引:

代码语言:javascript
复制
>>> x[i, j, k]
tensor([[[0.0256, 0.7366, 0.2528],
         [0.4223, 0.5206, 0.2898]],

        [[0.8194, 0.5240, 0.0055],
         [0.4556, 0.5804, 0.3169]],

        [[0.0270, 0.9176, 0.3904],
         [0.1433, 0.8450, 0.6557]]])

基本上,我遵循的一般方法是通过索引数组创建张量的相应访问模式,然后使用这些数组作为索引直接对张量进行切片。

我这样做实际上是为了升序排序,所以这里我得到的是top-k个最少的元素。要获得相反的结果,一个简单的解决方法是使用torch.sort(x[:, :, -1], descending = True)

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66906505

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档