首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >计算一个二维numpy数组中包含另一个一维数组的所有元素的所有行的最佳方法是什么?

计算一个二维numpy数组中包含另一个一维数组的所有元素的所有行的最佳方法是什么?
EN

Stack Overflow用户
提问于 2019-07-30 23:42:06
回答 4查看 215关注 0票数 0

计算一个二维numpy数组中包含另一个一维numpy数组的所有值的行数的最佳方法是什么?第二个数组的列数可以比一维数组的长度多。

代码语言:javascript
复制
elements = np.arange(4).reshape((2, 2))
test_elements = [2, 3]
somefunction(elements, test_elements)

我期望函数返回1。

代码语言:javascript
复制
elements = np.arange(15).reshape((5, 3))

# array([[ 0,  1,  2],
#       [ 3,  4,  5],
#       [ 6,  7,  8],
#       [ 9, 10, 11],
#       [12, 13, 14]])

test_elements = [4, 3]
somefunction(elements, test_elements)

也应该返回1。

必须包括一维数组的所有元素。如果在一行中只找到几个元素,则不算数。因此:

代码语言:javascript
复制
elements = np.arange(15).reshape((5, 3))

# array([[ 0,  1,  2],
#       [ 3,  4,  5],
#       [ 6,  7,  8],
#       [ 9, 10, 11],
#       [12, 13, 14]])

test_elements = [3, 4, 10]
somefunction(elements, test_elements)

也应该返回0。

EN

回答 4

Stack Overflow用户

发布于 2019-07-31 00:01:04

创建找到的元素的布尔数组,然后按行使用。这将避免在同一行中有多个值,最后使用sum对行进行计数。

代码语言:javascript
复制
np.any(np.isin(elements, test), axis=1).sum()

输出

代码语言:javascript
复制
>>> elements
array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11],
       [12, 13, 14]])
>>> test = [1, 6, 7, 4]
>>> np.any(np.isin(elements, test), axis=1).sum()
3
票数 0
EN

Stack Overflow用户

发布于 2019-07-31 00:02:31

(编辑:好的,现在我实际上有了更多的时间来弄清楚到底是怎么回事。)

这里有两个问题:

  1. 计算复杂度取决于两个输入的大小,1D benchmark plot
  2. 不能很好地捕捉到这一点,实际计时由inputs

的变化决定

这个问题可以分为两个部分:

遍历子集检查的循环检查,这基本上是嵌套循环二次操作(在最坏的情况下)

我们知道,对于足够大的输入,循环遍历行在NumPy中更快,在纯Python中更慢。

作为参考,让我们考虑这两种方法:

代码语言:javascript
复制
# pure Python approach
def all_in_by_row_flt(arr, elems=ELEMS):
    return sum(1 for row in arr if all(e in row for e in elems))

# NumPy apprach (based on @Mstaino answer)
def all_in_by_row_np(arr, elems=ELEMS):
    def _aaa_helper(row, e=elems):
        return np.isin(e, row)
    return np.sum(np.all(np.apply_along_axis(_aaa_helper, 1, arr), 1))

然后,考虑子集检查操作,如果输入在较少的循环中执行检查,则纯Python循环比NumPy更快。相反,如果需要足够多的循环,那么NumPy实际上可以更快。在此之上,还有遍历各行的循环,但是因为子集检查操作是二次的,并且具有不同的常量系数,所以在某些情况下,尽管行循环在NumPy中更快(因为行数将足够大),但是在纯Python中整体操作更快。这就是我在早期的基准测试中遇到的情况,并且对应于子集检查始终(或几乎)为False,并且在几个循环中确实失败的情况。一旦子集检查开始需要更多的循环,Python only方法就开始落后,对于大多数(如果不是全部)行的子集检查实际上是True的情况,NumPy方法实际上更快。

NumPy和纯Python方法之间的另一个关键区别是,纯Python使用延迟计算,而NumPy不使用,并且实际上需要创建潜在的大型中间对象来减慢计算速度。最重要的是,NumPy遍历了两次行(一次在sum()中,一次在np.apply_along_axis()中),而纯Python只迭代一次。

其他使用set().issubset()的方法,例如来自@GZ0的答案:

代码语言:javascript
复制
def all_in_by_row_set(arr, elems=ELEMS):
    elems = set(elems)
    return sum(map(elems.issubset, row))

当涉及到子集检查时,与显式编写嵌套循环相比,它们具有不同的时序,但它们仍然受到较慢的外部循环的影响。

那么,下一步呢?

答案是使用CythonNumba。我们的想法是获得类NumPy(读: C)所有时间的速度(不仅仅是足够大的输入),惰性计算和最少的行循环次数。

Cython方法的一个示例(使用%load_ext Cython魔术在IPython中实现)是:

代码语言:javascript
复制
%%cython --cplus -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True


cdef long all_in_by_row_c(long[:, :] arr, long[:] elems) nogil:
    cdef long result = 0
    I = arr.shape[0]
    J = arr.shape[1]
    K = elems.shape[0]
    for i in range(I):
        is_subset = True
        for k in range(K):
            is_contained = False
            for j in range(J):
                if elems[k] == arr[i, j]:
                    is_contained = True
                    break
            if not is_contained:
                is_subset = False
                break
        result += 1 if is_subset else 0
    return result


def all_in_by_row_cy(long[:, :] arr, long[:] elems):
    return all_in_by_row_c(arr, elems)

而类似的Numba代码为:

代码语言:javascript
复制
import numba as nb


@nb.jit(nopython=True, nogil=True)
def all_in_by_row_jit(arr, elems=ELEMS):
    result = 0
    n_rows, n_cols = arr.shape
    for i in range(n_rows):
        is_subset = True
        for e in elems:
            is_contained = False
            for r in arr[i, :]:
                if e == r:
                    is_contained = True
                    break
            if not is_contained:
                is_subset = False
                break
        result += 1 if is_subset else 0
    return result

现在,从时间上讲,我们得到以下结果(对于相对较少的行数):

代码语言:javascript
复制
arr.shape=(100, 1000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy  120 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_jit 129 µs ± 131 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_flt 2.44 ms ± 13.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_set 9.98 ms ± 52.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_np  13.7 ms ± 52.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

arr.shape=(100, 2000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy  1.45 ms ± 24.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Func: all_in_by_row_jit 1.52 ms ± 4.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Func: all_in_by_row_flt 30.1 ms ± 452 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_set 19.8 ms ± 56.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_np  18 ms ± 28.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

arr.shape=(100, 3000) elems.shape=(1000,) result=37
Func: all_in_by_row_cy  10.4 ms ± 31.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_jit 10.9 ms ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_flt 226 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 30.5 ms ± 92.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_np  21.9 ms ± 87.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

arr.shape=(100, 4000) elems.shape=(1000,) result=86
Func: all_in_by_row_cy  16.8 ms ± 32.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_jit 17.7 ms ± 42 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Func: all_in_by_row_flt 385 ms ± 2.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 39.5 ms ± 588 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_np  25.7 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

现在,最后一个块的速度变慢不能用第二维中增加的输入大小来解释。实际上,如果短路率增加了(例如,通过改变随机数组的值范围),那么对于最后一个块(相同的输入大小),会得到:

代码语言:javascript
复制
arr.shape=(100, 4000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy   152 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_jit  173 µs ± 4.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Func: all_in_by_row_flt  556 µs ± 8.56 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Func: all_in_by_row_set  39.7 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_np   31.5 ms ± 315 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

请注意,set()-based方法在某种程度上独立于短路率(因为基于散列的实现具有存在复杂性的~O(1)检查,但这是以散列预计算的费用为代价的,这些结果表明这可能不会比直接嵌套循环方法更快)。

最后,对于更大的行数:

代码语言:javascript
复制
arr.shape=(100000, 1000) elems.shape=(1000,) result=0
Func: all_in_by_row_cy  141 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_jit 150 ms ± 1.73 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Func: all_in_by_row_flt 2.6 s ± 28.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 10.1 s ± 216 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np  13.7 s ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

arr.shape=(100000, 2000) elems.shape=(1000,) result=34
Func: all_in_by_row_cy  1.2 s ± 753 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_jit 1.27 s ± 7.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_flt 24.1 s ± 119 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 19.5 s ± 270 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np  18 s ± 18.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

arr.shape=(100000, 3000) elems.shape=(1000,) result=33859
Func: all_in_by_row_cy  9.79 s ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_jit 10.3 s ± 5.55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_flt 3min 30s ± 1.13 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 30 s ± 57.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np  21.9 s ± 59.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

arr.shape=(100000, 4000) elems.shape=(1000,) result=86376
Func: all_in_by_row_cy  17 s ± 30.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_jit 17.9 s ± 13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_flt 6min 29s ± 293 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_set 38.9 s ± 33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Func: all_in_by_row_np  25.7 s ± 29.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

最后,请注意,Cython/Numba代码可以在算法上进行优化。

票数 0
EN

Stack Overflow用户

发布于 2019-07-31 00:19:26

可能有一种更有效的解决方案,但是如果您想要包含test_elements的“所有”元素的行,您可以反转np.isin并将其应用于每一行,如下所示:

代码语言:javascript
复制
np.apply_along_axis(lambda x: np.isin(test_elements, x), 1, elements).all(1).sum()
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57275404

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档