假设我有一个张量,利用torch.topk函数得到张量的最大k元及其指数。类似于下面的代码
>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1., 2., 3., 4., 5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))现在,假设我想将上面的最大k元素设置为0,但是保持它们在原始张量中的相同位置。我怎么能这么做?
如果k =3,那么新的张量应该如下所示:
张量( 1.,2.,0.,0.,0.)
基本上,我如何使用这些max k元素( topk() torch函数的返回)的标识符(将它们的值设置为0) --这些位置中的原始值?
我更喜欢用火炬法来做我想要的事情。如果不这样做,最好是尽可能有效地解决问题。
(谢谢你预先提供的帮助:)
发布于 2022-11-26 08:30:05
torch.topk函数返回提供维度上top-k元素的索引。可以使用torch.scatter执行重新分配操作。
>>> _, i = x.topk(k=3)
>>> x.scatter(dim=0, index=i, value=0)
tensor([1., 2., 0., 0., 0.])https://stackoverflow.com/questions/74577249
复制相似问题