首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >努巴jit伴枕

努巴jit伴枕
EN

Stack Overflow用户
提问于 2019-03-23 19:42:46
回答 2查看 14.2K关注 0票数 11

所以我想在numba jit的帮助下加速我写的程序。然而,jit似乎不兼容许多枕函数,因为它们使用try . except .jit无法处理的结构(这一点我说得对吗?)

我想出的一个相对简单的解决方案是复制我需要的枕源代码并删除try except部件(我已经知道它不会遇到错误,所以try部件无论如何都会正常工作)。

然而,我不喜欢这个解决办法,我也不确定它是否会奏效。

我的代码结构如下

代码语言:javascript
复制
import scipy.integrate as integrate
from scipy optimize import curve_fit
from numba import jit

def fitfunction():
    ...

@jit
def function(x):
    # do some stuff
    try:
        fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=(0,0,0), maxfev=500)
        for idx in some_list:
            integrated = integrate.quad(lambda x: fitfunction(fit_param), lower, upper)
    except:
        fit_param=(0,0,0)
        ...

现在,这会导致以下错误:

LoweringError:对象失败(对象模式后端)

我认为这是由于jit无法处理try except (如果我只将jit放在curve_fitintegrate.quad部件上并围绕自己的try except结构工作的话,它也不起作用)。

代码语言:javascript
复制
import scipy.integrate as integrate
from scipy optimize import curve_fit
from numba import jit

def fitfunction():
    ...

@jit
def integral(lower, upper):
    return integrate.quad(lambda x: fitfunction(fit_param), lower, upper)

@jit
def fitting(x, y, pzero, max_fev)
    return curve_fit(fitfunction, x, y, p0=pzero, maxfev=max_fev)


def function(x):
    # do some stuff
    try:
        fit_param, fit_cov = fitting(x, y, (0,0,0), 500)
        for idx in some_list:
            integrated = integral(lower, upper)
    except:
        fit_param=(0,0,0)
        ...

是否有一种方法可以将jitscipy.integrate.quadcurve_fit一起使用,而无需手动删除枕代码中的所有try except结构?

它还能加速密码吗?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-03-24 17:46:18

Numba只不过是,而不是,它是一个通用的库,用来加速代码的编写。有一类问题可以用numba以更快的方式解决(特别是在数组上有循环、数字处理的情况下),但其他的问题要么(1)不支持,要么(2)只稍微快一点,甚至慢得多。

..。它会加速密码吗?

SciPy已经是一个高性能的库,所以在大多数情况下,我希望numba执行得更差(或者很少:稍微好一点)。您可以做一些侧写来了解瓶颈是否真的存在于您jit泰德的代码中,然后您可以得到一些改进。但我怀疑瓶颈将出现在SciPy的编译代码中,而且编译后的代码可能已经得到了很大的优化(所以确实不太可能找到一个只能与该代码“竞争”的实现)。

是否有一种方法可以与scipy.integrate.quad和curve_fit一起使用jit,而无需手动删除除jit代码中的结构之外的所有try?

正如您正确地假设的那样,numba现在根本不支持tryexcept

2.6.1.语言 2.6.1.1.构式 Numba尽力支持尽可能多的Python语言,但是Numba编译的函数中没有一些语言特性。目前不支持以下Python语言特性: ..。

  • 异常处理(try ..excepttry .finally)

所以这里的答案是No。

票数 12
EN

Stack Overflow用户

发布于 2021-07-26 06:21:27

现在,tryexcept和numba一起工作。然而,农巴和枕叶仍然不相容。是的,Scipy调用编译了C和Fortran,但它是以numba无法处理的方式进行的。

幸运的是,还有其他的选择,可以很好地应用于农巴!下面,我使用NumbaQuadpackNumbaMinpack进行一些曲线拟合和集成,类似于您的示例代码。免责声明:我把这些包裹放在一起。在下面,我也给出了一个等价的执行方案。

的实现速度是 (NumbaQuadpack和NumbaMinpack)的18倍。

使用Scipy替代品(0.23ms)

代码语言:javascript
复制
from NumbaQuadpack import quadpack_sig, dqags
from NumbaMinpack import minpack_sig, lmdif
import numpy as np
import numba as nb
import timeit
np.random.seed(0)

x = np.linspace(0,2*np.pi,100)
y = np.sin(x)+ np.random.rand(100)

@nb.jit
def fitfunction(x, A, B):
    return A*np.sin(B*x)

@nb.cfunc(minpack_sig)
def fitfunction_optimize(u_, fvec, args_):
    u = nb.carray(u_,(2,))
    args = nb.carray(args_,(200,))
    A, B = u
    x = args[:100]
    y = args[100:]
    for i in range(100):
        fvec[i] = fitfunction(x[i], A, B) - y[i] 
optimize_ptr = fitfunction_optimize.address

@nb.cfunc(quadpack_sig)
def fitfunction_integrate(x, data):
    A = data[0]
    B = data[1]
    return fitfunction(x, A, B)
integrate_ptr = fitfunction_integrate.address

@nb.njit
def fast_function():  
    try:
        neqs = 100
        u_init = np.array([2.0,.8],np.float64)
        args = np.append(x,y)
        fitparam, fvec, success, info = lmdif(optimize_ptr , u_init, neqs, args)
        if not success:
            raise Exception

        lower = 0.0
        uppers = np.linspace(np.pi,np.pi*2.0,200)
        solutions = np.empty(len(uppers))
        for i in range(len(uppers)):
            solutions[i], abserr, success = dqags(integrate_ptr, lower, uppers[i], data = fitparam)
            if not success:
                raise Exception
    except:
        print('doing something else')
        
fast_function()
iters = 1000
t_nb = timeit.Timer(fast_function).timeit(number=iters)/iters
print(t_nb)

使用Scipy (4.4毫秒)

代码语言:javascript
复制
import scipy.integrate as integrate
from scipy.optimize import curve_fit
import numpy as np
import numba as nb
import timeit

np.random.seed(0)

x = np.linspace(0,2*np.pi,100)
y = np.sin(x)+ np.random.rand(100)

@nb.jit
def fitfunction(x, A, B):
    return A*np.sin(B*x)

def function():
    try:
        p0 = (2.0,.8)
        fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=p0, maxfev=500)

        lower = 0.0
        uppers = np.linspace(np.pi,np.pi*2.0,200)
        solutions = np.empty(len(uppers))
        for i in range(len(uppers)):
            solutions[i], abserr = integrate.quad(fitfunction, lower, uppers[i], args = tuple(fit_param))
    except:
        print('do something else')

function()
iters = 1000
t_sp = timeit.Timer(function).timeit(number=iters)/iters
print(t_sp)
票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55317665

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档