我想对下面的do_permutation函数进行矢量化。我想删除python中的循环,改为使用numpy方法。
import numpy as np
def do_permutation(indices):
perm = np.zeros(len(indices), dtype='int32')
for i, o in enumerate(indices):
perm[o] = i
return perm
assert list(do_permutation([3, 2, 4, 1, 0])) == [4, 3, 1, 0, 2]发布于 2021-05-27 19:46:00
一般的解决方案是
do_perm = np.argsort但是,此解决方案是O(N log N)。对于足够长的索引,实际上是从0到N的排列,O(N)解决方案是值得的。在这种情况下,使用直接排序:
def do_perm(indices):
perm = np.empty_like(indices)
perm[indices] = np.arange(len(indices))
return perm最后一行也可以写成
np.put(perm, indices, np.arange(indices.size))使用np.put可以透明地处理解开的indices。
发布于 2021-05-27 19:39:07
您可以将其赋值给使用给定索引编制索引的数组的片段。
def do_permutation(indices):
N = len(indices);
perm = np.zeros(N, dtype='int32')
perm[indices] = np.arange(N)
return permhttps://stackoverflow.com/questions/67721253
复制相似问题