首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Pytorch argmax跨多维

Pytorch argmax跨多维
EN

Stack Overflow用户
提问于 2021-10-10 19:29:17
回答 2查看 279关注 0票数 1

我有一个4D张量,我想得到最后两个维度的argmax。torch.argmax只接受整数作为"dim“参数,而不接受元组。

我如何才能做到这一点呢?

这是我的想法,但我不知道如何匹配我的两个“索引”张量的维度。original_array形状为1,512,37,59。

代码语言:javascript
复制
max_vals, indices_r = torch.max(original_array, dim=2)
max_vals, indices_c = torch.max(max_vals, dim=2)
indices = torch.hstack((indices_r, indices_c))
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-10-10 20:19:50

正如其他人所提到的,最好的方法是将最后两个维度拉平并应用argmax

代码语言:javascript
复制
original_array = torch.rand(1, 512, 37, 59)
original_flatten = original_array.view(1, 512, -1)
_, max_ind = original_flatten.max(-1)

。。您将获得最大值的线性索引。如果您想要最大值的2D indecies,可以使用列数“取消扁平化”indecies

代码语言:javascript
复制
# 59 is the number of columns for the (37, 59) part
torch.stack([max_ind // 59, max_ind % 59], -1)

这将为您提供一个(1, 512, 2),其中每个最后2个dim都包含2D坐标。

票数 3
EN

Stack Overflow用户

发布于 2021-10-10 20:04:36

您可以使用torch.flatten展平最后两个维度,并对其应用torch.argmax

代码语言:javascript
复制
>>> x = torch.rand(2,3,100,100)
>>> x.flatten(-2).argmax(-1)
tensor([[2660, 6328, 8166],
        [5934, 5494, 9717]])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69518359

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档