首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >电筒等效tf.gather

电筒等效tf.gather
EN

Stack Overflow用户
提问于 2018-12-09 23:05:45
回答 2查看 2.9K关注 0票数 1

我在将一些代码从tensorflow移植到Py火炬时遇到了一些困难。

所以我有一个矩阵,尺寸为10x30,代表10个例子,每个例子都有30个特征。然后,我有另一个矩阵,维数为10x5,包含第一个矩阵中每一个例子最接近的5个例子的索引。我想用第二个矩阵中包含的指数来“收集”第一个矩阵中每个例子的5个壁橱例子,留给我一个10x5x30形状的三维张量。

在tensorflow中,这是用tf.gather(matrix1, matrix2)完成的。有人知道我怎么能在火把里做到这一点吗?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-12-10 19:55:24

这个怎么样?

代码语言:javascript
复制
matrix1 = torch.randn(10, 30)
matrix2 = torch.randint(high=10, size=(10, 5))
gathered = matrix1[matrix2]

它使用了用整数数组进行索引的技巧。

票数 6
EN

Stack Overflow用户

发布于 2021-10-01 06:31:55

我有一个场景,必须将gather()应用于整数数组。

考试-01

代码语言:javascript
复制
torch.Tensor().gather(dim, input_tensor)
代码语言:javascript
复制
# here,
#   input_tensor -> tensor(1)
my_list = [0, 1, 2, 3, 4]
my_tensor = torch.IntTensor(my_list)
output = my_tensor.gather(0, input_tensor) # 0 -> is the dimension

考试-02

代码语言:javascript
复制
torch.gather(param_tensor, dim, input_tensor)
代码语言:javascript
复制
# here,
#   input_tensor -> tensor(1)
my_list = [0, 1, 2, 3, 4]
my_tensor = torch.IntTensor(my_list)
output = torch.gather(my_tensor, 0, input_tensor) # 0 -> is the dimension
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53697596

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档