首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么Numba不改进这个递归函数

为什么Numba不改进这个递归函数
EN

Stack Overflow用户
提问于 2020-06-14 12:17:07
回答 3查看 657关注 0票数 1

我有一个具有非常简单结构的真/假值数组:

代码语言:javascript
复制
# the real array has hundreds of thousands of items
positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)

我想遍历这个数组并输出发生更改的位置(true变为false或相反)。为此,我提出了两种不同的方法:

  • --递归二进制搜索(查看是否所有值都是相同的,如果没有,则拆分成两部分)
  • --纯迭代搜索(循环遍历所有元素并与前一个/下一个元素进行比较)

这两个版本都给出了我想要的结果,但是Numba对一个版本的影响比另一个版本更大。对于一个300 k值的虚拟数组,下面是性能结果:

300 k元素数组的

性能结果

binary-search)

  • Numba

  • 纯Python二进制-搜索运行在11 ms

  • 纯Python迭代中搜索运行速度为1.1 s(比二进制文件慢100倍)搜索运行速度为5 ms(比纯
    • 迭代运行快2倍)搜索运行900 s(比纯Python的速度快1,200倍)

因此,当使用Numba时,binary_search比iterative_search慢5倍,而理论上它应该快100倍(如果适当加速,它应该在9 s内运行)。

如何使Numba加速二进制搜索,就像它加速迭代搜索一样?

这两种方法的代码(以及一个示例position数组)都可以在以下公共要点上获得:https://gist.github.com/JivanRoquet/d58989aa0a4598e060ec2c705b9f3d8f

注意: Numba没有在对象模式下运行binary_search(),因为当提到nopython=True__时,它不会抱怨并愉快地编译函数。

EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2020-06-14 14:00:04

使用np.diff可以找到值更改的位置,不需要运行更复杂的算法,也不需要使用numba

代码语言:javascript
复制
positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=np.bool)
dpos = np.diff(positions)
# array([ True, False, False,  True, False, False, False,  True, False, False])

这是可行的,因为False - True == -1np.bool(-1) == True

它在我的电池供电(=节流模式)和几年前的笔记本电脑上表现很好:

代码语言:javascript
复制
In [52]: positions = np.random.randint(0, 2, size=300_000, dtype=bool)          

In [53]: %timeit np.diff(positions)                                             
633 µs ± 4.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

我认为用numba编写自己的差异应该会产生类似的性能。

编辑:最后一条语句是假的,我使用numba实现了一个简单的diff函数,它比numpy函数快10倍以上(但它的特性显然也要少得多,但应该足够完成这项任务):

代码语言:javascript
复制
@numba.njit 
def ndiff(x): 
    s = x.size - 1 
    r = np.empty(s, dtype=x.dtype) 
    for i in range(s): 
        r[i] = x[i+1] - x[i] 
    return r

In [68]: np.all(ndiff(positions) == np.diff(positions))                            
Out[68]: True

In [69]: %timeit ndiff(positions)                                               
46 µs ± 138 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
票数 4
EN

Stack Overflow用户

发布于 2020-06-14 13:45:33

主要的问题是你没有进行苹果与苹果的比较。您提供的不是同一算法的迭代和递归版本。您正在提出两个根本不同的算法,它们恰好是递归/迭代算法。

尤其是在递归方法中,您使用的NumPy内置量要多得多,因此难怪这两种方法之间有如此惊人的差别。Numba JITting在避免NumPy内置时更有效,这也就不足为奇了。最后,递归算法似乎效率较低,因为在np.all()np.any()调用中存在一些隐藏的嵌套循环,而迭代方法是避免的,所以即使您要用Numba更有效地加速使用纯Python编写所有代码,递归方法也要慢一些。

通常,迭代方法比递归等效方法更快,因为它们避免了函数调用开销(与纯Python函数相比,JIT加速函数的调用开销最小)。因此,我建议不要尝试以递归的形式重写算法,结果发现它更慢。

编辑

在一个简单的np.diff()可以发挥作用的前提下,Numba仍然是非常有益的:

代码语言:javascript
复制
import numpy as np
import numba as nb


@nb.jit
def diff(arr):
    n = arr.size
    result = np.empty(n - 1, dtype=arr.dtype)
    for i in range(n - 1):
        result[i] = arr[i + 1] ^ arr[i]
    return result


positions = np.random.randint(0, 2, size=300_000, dtype=bool)
print(np.allclose(np.diff(positions), diff(positions)))
# True


%timeit np.diff(positions)
# 1000 loops, best of 3: 603 µs per loop
%timeit diff(positions)
# 10000 loops, best of 3: 43.3 µs per loop

由于Numba方法的速度大约快了13倍(当然,在这个测试中,里程可能会有所不同)。

票数 3
EN

Stack Overflow用户

发布于 2020-06-14 13:13:17

要点是,只有使用Python机器的逻辑部分才能被加速--用一些等价的C逻辑来替换它,这种逻辑去掉了Python运行时的大部分复杂性(和灵活性)(我猜这就是Numba所做的)。

NumPy操作中的所有繁重工作都已经用C实现了,而且非常简单(因为NumPy数组是保持常规C类型的连续内存块),所以Numba只能剥离与Python接口的部分。

您的“二进制搜索”算法做了更多的工作,同时更多地使用了NumPy的向量运算,因此可以以这种方式加速更少的工作。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62372395

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档