What is the best way to vectorize grid construction?
In [1]: import numpy as np
In [2]: mesh = np.linspace(-1, 1, 3000)
In [3]: rowwise, colwise = np.meshgrid(mesh, mesh)
In [4]: f = lambda x, y: np.where(x > y, x**2, x**3)
# Using 2D arrays:
In [5]: %timeit f(colwise, rowwise)
285 ms ± 2.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Using 1D array and list-comprehension:
In [6]: %timeit np.array([f(x, mesh) for x in mesh])
58 ms ± 2.69 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Equivalent result
In [7]: np.allclose(f(colwise, rowwise), np.array([f(x, mesh) for x in mesh]))
True

Posted on 2021-12-07 17:43:13
In [1]: mesh = np.linspace(-1, 1, 3000)
   ...: rowwise, colwise = np.meshgrid(mesh, mesh)
   ...: f = lambda x, y: np.where(x > y, x**2, x**3)

Also, let's create sparse grids:
In [2]: r1,c1 = np.meshgrid(mesh,mesh,sparse=True)
In [3]: rowwise.shape
Out[3]: (3000, 3000)
In [4]: r1.shape
Out[4]: (1, 3000)

With the sparse grids, the time is even better than the iteration:
In [5]: timeit f(colwise, rowwise)
645 ms ± 57.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [6]: timeit f(c1,r1)
108 ms ± 3.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [7]: timeit np.array([f(x, mesh) for x in mesh])
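166 ms ± 13.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Another answer emphasizes caching. Other posts have shown that a modest number of iterations can be faster than working with very large arrays (for example, with matmul). I don't know whether it is caching or some other memory-management complication that slows this down.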
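But at 3000*3000*8 bytes (about 72 MB per array), I'm not sure that's the issue. Instead, I think it's the time the x**2 and x**3 expressions take.
The arguments to where are evaluated before they are passed in. So both x**2 and x**3 are computed for every element of the full grid, even though each element keeps only one of them.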
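Here is a small illustration of this point, not taken from the original answer. np.sqrt stands in for a branch that is only valid on part of the input. The whole branch array is built before np.where runs, so the negative element still triggers a warning. Computing the branch only on the masked elements avoids that work. Whether masking pays off for x**2 and x**3 (it adds indexing overhead) is something to measure rather than assume.

```python
import warnings

import numpy as np

x = np.array([-4.0, 0.0, 9.0])

# np.where only receives finished arrays: np.sqrt(x) is evaluated on every
# element, including -4.0, before the condition is applied.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = np.where(x >= 0, np.sqrt(x), 0.0)

print(result)                                 # [0. 0. 3.]
print([w.category.__name__ for w in caught])  # a RuntimeWarning from sqrt(-4.0)

# Alternative: evaluate the branch only where the condition holds.
mask = x >= 0
masked = np.zeros_like(x)
masked[mask] = np.sqrt(x[mask])  # sqrt never sees -4.0, so no warning here
print(masked)                                 # [0. 0. 3.]
```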
The condition expression doesn't take much time:
In [8]: timeit colwise>rowwise
24.2 ms ± 71.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

But the power expressions on the (3000,3000) arrays account for most of the total time:
In [9]: timeit rowwise**3
467 ms ± 8.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Compare that with the time for the sparse equivalent:
In [10]: timeit r1**3
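142 µs ± 150 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

That is about 3288 times faster. The dense array has only 3000 times as many elements, so the full-array power scales somewhat worse than O(n).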
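This gap explains the fast sparse timing. In f(c1, r1), the powers are computed on the (3000, 1) column, and only the comparison and np.where broadcast to the full (3000, 3000) shape. The check below uses a smaller 300-point mesh chosen only for speed. It also shows that np.newaxis indexing, which the original answer does not use, builds the same sparse grids.

```python
import numpy as np

# Smaller than the answer's 3000 points so the check runs instantly.
mesh = np.linspace(-1, 1, 300)
f = lambda x, y: np.where(x > y, x**2, x**3)

rowwise, colwise = np.meshgrid(mesh, mesh)     # dense: both (300, 300)
r1, c1 = np.meshgrid(mesh, mesh, sparse=True)  # sparse: (1, 300) and (300, 1)

# The same sparse grids built directly with broadcasting axes.
r2 = mesh[np.newaxis, :]  # (1, 300)
c2 = mesh[:, np.newaxis]  # (300, 1)

dense = f(colwise, rowwise)  # x**2 and x**3 each touch 90,000 values
sparse = f(c1, r1)           # x**2 and x**3 each touch 300 values

print(sparse.shape)               # (300, 300)
print(np.allclose(dense, sparse)) # True
```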
Repeated multiplication is better:
In [11]: timeit rowwise*rowwise*rowwise
116 ms ± 12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [f(x, mesh) for x in mesh], x**3 operates on a scalar, so it is fast even though it is repeated 3000 times.
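The sketch below combines two observations from the answer: repeated multiplication in place of `**`, and powers computed on small operands rather than the full grid. The name f_mul and the 1000-point mesh are assumptions for illustration. No timings are claimed here; the printed numbers depend entirely on the machine.

```python
import timeit

import numpy as np


def f_pow(x, y):
    # The question's function: x**2 where x > y, otherwise x**3.
    return np.where(x > y, x**2, x**3)


def f_mul(x, y):
    # Hypothetical variant: square once with a multiplication and reuse it
    # for the cube, so no ** operator is applied to the array.
    x2 = x * x
    return np.where(x > y, x2, x2 * x)


# Smaller than the answer's 3000 points to keep memory use modest.
mesh = np.linspace(-1, 1, 1000)
rowwise, colwise = np.meshgrid(mesh, mesh)     # dense (1000, 1000) grids
r1, c1 = np.meshgrid(mesh, mesh, sparse=True)  # sparse (1, 1000) and (1000, 1)

candidates = {
    "f_pow, dense": lambda: f_pow(colwise, rowwise),
    "f_mul, dense": lambda: f_mul(colwise, rowwise),
    "f_mul, sparse": lambda: f_mul(c1, r1),
}
for name, call in candidates.items():
    # Best of 3 single calls, reported in milliseconds.
    best = min(timeit.repeat(call, number=1, repeat=3))
    print(f"{name}: {best * 1e3:.1f} ms")
```

In the sparse call, x * x and x2 * x run on only 1000 values. Only the comparison and np.where do full-grid work.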
In fact, if we take the power calculations out of the timing, the full-array where is relatively fast:
In [15]: %%timeit x2,x3 = rowwise**2, rowwise**3
...: np.where(rowwise>colwise, x2,x3)
89.8 ms ± 3.99 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

https://stackoverflow.com/questions/70255503