首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >无广播的`torch.gather`

无广播的`torch.gather`
EN

Stack Overflow用户
提问于 2022-07-03 10:43:28
回答 1查看 149关注 0票数 0

我有一些分批输入的形状x [batch, time, feature]和一些分批指数i的形状[batch, new_time],我想收集到时间昏暗的x。作为此操作的输出,我需要一个形状为y的张量[batch, new_time, feature],其值如下:

代码语言:javascript
复制
y[b, t', f] = x[b, i[b, t'], f]

在Tensorflow中,我可以使用batch_dims: int argument of tf.gathery = tf.gather(x, i, axis=1, batch_dims=1)来实现这一点。

在PyTorch中,我可以想到一些类似的功能:

当然是

  1. torch.gather,但这与Tensorflow的batch_dims没有相似之处。torch.gather的输出总是与指数相同的形状。因此,我需要在将torch.gather.

传送给i之前,先将feature dim广播到i中。

但在这里,指数必须是一维的.因此,要使其工作,我需要取消广播x添加一个"batch * new_time“暗淡,然后在torch.index_select重塑输出。。

  1. torch.nn.functional.embedding.在这里,嵌入矩阵将对应于x。但是这个嵌入函数不支持批处理的权重,因此我遇到了与torch.index_select (查看代码,torch.index_select)相同的问题。

是否有可能在不依赖广播的情况下完成这样的采集操作,这对于大dims来说效率很低?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-07-03 11:35:28

这实际上是最常见的情况:当输入和索引张量不完全匹配维数时。但是,您仍然可以利用torch.gather,因为您可以重写表达式:

代码语言:javascript
复制
y[b, t, f] = x[b, i[b, t], f]

作为:

代码语言:javascript
复制
y[b, t, f] = x[b, i[b, t, f], f]

这确保了所有三个张量都有相同的维数。这揭示了i上的第三个维度,我们可以通过解压缩一个维度并将其扩展到x的形状来轻松地创建这个维度。您可以使用i[:,None].expand_as(x)来完成这个任务。

下面是一个很小的例子:

代码语言:javascript
复制
>>> b = 2; t = 3; f = 1
>>> x = torch.rand(b, t, f)
>>> i = torch.randint(0, t, (b, f))

>>> x.gather(1, i[:,None].expand_as(x))
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72845808

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档