首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >二维数组上的numpy.where比一维数组上numpy.where的列表理解慢

二维数组上的numpy.where比一维数组上numpy.where的列表理解慢
EN

Stack Overflow用户
提问于 2021-12-07 05:52:32
回答 1查看 130关注 0票数 1
  1. 为什么在这种情况下,Numpy比列表理解慢?

grid-construction?矢量化的最佳方法是什么?

代码语言:javascript
复制
In [1]: import numpy as np

In [2]: mesh = np.linspace(-1, 1, 3000)

In [3]: rowwise, colwise = np.meshgrid(mesh, mesh)

In [4]: f = lambda x, y: np.where(x > y, x**2, x**3)

# Using 2D arrays:
In [5]: %timeit f(colwise, rowwise)
285 ms ± 2.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Using 1D array and list-comprehension:
In [6]: %timeit np.array([f(x, mesh) for x in mesh])
58 ms ± 2.69 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# Equivalent result
In [7]: np.allclose(f(colwise, rowwise), np.array([f(x, mesh) for x in mesh]))
True
EN

回答 1

Stack Overflow用户

发布于 2021-12-07 17:43:13

代码语言:javascript
复制
In [1]: In [2]: mesh = np.linspace(-1, 1, 3000)
   ...: In [3]: rowwise, colwise = np.meshgrid(mesh, mesh)
   ...: In [4]: f = lambda x, y: np.where(x > y, x**2, x**3)

此外,让我们创建稀疏网格:

代码语言:javascript
复制
In [2]: r1,c1 = np.meshgrid(mesh,mesh,sparse=True)
In [3]: rowwise.shape
Out[3]: (3000, 3000)
In [4]: r1.shape
Out[4]: (1, 3000)

对于稀疏网格,时间甚至比迭代更好:

代码语言:javascript
复制
In [5]: timeit f(colwise, rowwise)
645 ms ± 57.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [6]: timeit f(c1,r1)
108 ms ± 3.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [7]: timeit np.array([f(x, mesh) for x in mesh])
166 ms ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

另一个答案强调缓存。其他文章已经表明,与使用非常大的数组(例如在使用matmul时)相比,少量的迭代可以更快。我不知道是否是缓存或其他内存管理的复杂性减缓了这一点。

但在3000*3000*8字节中,我不确定这就是问题所在。相反,我认为现在是x**2x**3表达式需要的时候了。

where的参数在传入之前进行计算。

条件表达式所需的时间不多:

代码语言:javascript
复制
In [8]: timeit colwise>rowwise
24.2 ms ± 71.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

但是(3000,3000)数组的功率表达式占总时间的绝大部分:

代码语言:javascript
复制
In [9]: timeit rowwise**3
467 ms ± 8.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

将此与稀疏等效所需时间进行对比:

代码语言:javascript
复制
In [10]: timeit r1**3
142 µs ± 150 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

这一次速度快了3288倍,这比O(n)缩放要糟糕一些。

重复乘法更好:

代码语言:javascript
复制
In [11]: timeit rowwise*rowwise*rowwise
116 ms ± 12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

[f(x, mesh) for x in mesh]中,x**3在标量上运行,所以速度很快,尽管它重复了3000次。

实际上,如果我们将功率计算从计时中剔除,整个阵列where的速度相对较快:

代码语言:javascript
复制
In [15]: %%timeit x2,x3 = rowwise**2, rowwise**3
    ...: np.where(rowwise>colwise, x2,x3)
89.8 ms ± 3.99 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70255503

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档