我在将一些代码从tensorflow移植到Py火炬时遇到了一些困难。
所以我有一个矩阵,尺寸为10x30,代表10个例子,每个例子都有30个特征。然后,我有另一个矩阵,维数为10x5,包含第一个矩阵中每一个例子最接近的5个例子的索引。我想用第二个矩阵中包含的指数来“收集”第一个矩阵中每个例子的5个壁橱例子,留给我一个10x5x30形状的三维张量。
在tensorflow中,这是用tf.gather(matrix1, matrix2)完成的。有人知道我怎么能在火把里做到这一点吗?
发布于 2018-12-10 19:55:24
这个怎么样?
matrix1 = torch.randn(10, 30)
matrix2 = torch.randint(high=10, size=(10, 5))
gathered = matrix1[matrix2]它使用了用整数数组进行索引的技巧。
发布于 2021-10-01 06:31:55
我有一个场景,必须将gather()应用于整数数组。
考试-01
torch.Tensor().gather(dim, input_tensor)# here,
# input_tensor -> tensor(1)
my_list = [0, 1, 2, 3, 4]
my_tensor = torch.IntTensor(my_list)
output = my_tensor.gather(0, input_tensor) # 0 -> is the dimension考试-02
torch.gather(param_tensor, dim, input_tensor)# here,
# input_tensor -> tensor(1)
my_list = [0, 1, 2, 3, 4]
my_tensor = torch.IntTensor(my_list)
output = torch.gather(my_tensor, 0, input_tensor) # 0 -> is the dimensionhttps://stackoverflow.com/questions/53697596
复制相似问题