我有一些分批输入的形状x [batch, time, feature]和一些分批指数i的形状[batch, new_time],我想收集到时间昏暗的x。作为此操作的输出,我需要一个形状为y的张量[batch, new_time, feature],其值如下:
y[b, t', f] = x[b, i[b, t'], f]在Tensorflow中,我可以使用batch_dims: int argument of tf.gather:y = tf.gather(x, i, axis=1, batch_dims=1)来实现这一点。
在PyTorch中,我可以想到一些类似的功能:
当然是
torch.gather,但这与Tensorflow的batch_dims没有相似之处。torch.gather的输出总是与指数相同的形状。因此,我需要在将torch.gather.传送给i之前,先将feature dim广播到i中。
但在这里,指数必须是一维的.因此,要使其工作,我需要取消广播x添加一个"batch * new_time“暗淡,然后在torch.index_select重塑输出。。
torch.nn.functional.embedding.在这里,嵌入矩阵将对应于x。但是这个嵌入函数不支持批处理的权重,因此我遇到了与torch.index_select (查看代码,torch.index_select)相同的问题。是否有可能在不依赖广播的情况下完成这样的采集操作,这对于大dims来说效率很低?
发布于 2022-07-03 11:35:28
这实际上是最常见的情况:当输入和索引张量不完全匹配维数时。但是,您仍然可以利用torch.gather,因为您可以重写表达式:
y[b, t, f] = x[b, i[b, t], f]作为:
y[b, t, f] = x[b, i[b, t, f], f]这确保了所有三个张量都有相同的维数。这揭示了i上的第三个维度,我们可以通过解压缩一个维度并将其扩展到x的形状来轻松地创建这个维度。您可以使用i[:,None].expand_as(x)来完成这个任务。
下面是一个很小的例子:
>>> b = 2; t = 3; f = 1
>>> x = torch.rand(b, t, f)
>>> i = torch.randint(0, t, (b, f))
>>> x.gather(1, i[:,None].expand_as(x))https://stackoverflow.com/questions/72845808
复制相似问题