首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >torch.gather的倒数

torch.gather的倒数
EN

Stack Overflow用户
提问于 2021-07-24 12:46:19
回答 1查看 725关注 0票数 1

给定一个输入张量x和一个指数idxs张量,我想检索索引不存在于idxs中的x的所有元素。也就是说,接受与torch.gather函数输出相反的输出。

使用torch.gather的示例

代码语言:javascript
复制
>>> x = torch.arange(30).reshape(3,10)
>>> idxs = torch.tensor([[1,2,3], [4,5,6], [7,8,9]], dtype=torch.long)
>>> torch.gather(x, 1, idxs)
tensor([[ 1,  2,  3],
        [14, 15, 16],
        [27, 28, 29]])

我真正想要实现的是

代码语言:javascript
复制
tensor([[ 0,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26]])

它可以是一个有效和高效的实现,可能使用torch实用程序吗?我不想用任何的循环。

我假设idxs在其最深的维度中只有唯一的元素。例如,idxs将是调用torch.topk的结果。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-07-24 14:08:57

您可能需要构造形状(x.size(0), x.size(1)-idxs.size(1))的张量(此处为(3, 7))。它对应于idxs的互补指数,关于x的形状,即

代码语言:javascript
复制
tensor([[0, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6]])

我建议首先建立一个张量,形状类似于x,它揭示了我们想要保持的位置和我们想要丢弃的位置,一种面具。这可以使用torch.scatter来完成。这在本质上分散了0在需要的位置,即m[i, idxs[i][j]] = 0

代码语言:javascript
复制
>>> m = torch.ones_like(x).scatter(1, idxs, 0)
tensor([[1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])

然后获取非零(idxs的互补部分)。选择axis=1上的第二个指数,并根据目标张量进行整形:

代码语言:javascript
复制
>>> idxs_ = m.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
tensor([[0, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6]])

现在你知道该怎么做了对吧?与您给出的torch.gather示例相同,但这次使用idxs_

代码语言:javascript
复制
>>> torch.gather(x, 1, idxs_)
tensor([[ 0,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26]])

总结如下:

代码语言:javascript
复制
>>> idxs_ = torch.ones_like(x).scatter(1, idxs, 0) \
        .nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))

>>> torch.gather(x, 1, idxs_)
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68510107

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档