所以我想在numba jit的帮助下加速我写的程序。然而,jit似乎不兼容许多枕函数,因为它们使用try . except .jit无法处理的结构(这一点我说得对吗?)
我想出的一个相对简单的解决方案是复制我需要的枕源代码并删除try except部件(我已经知道它不会遇到错误,所以try部件无论如何都会正常工作)。
然而,我不喜欢这个解决办法,我也不确定它是否会奏效。
我的代码结构如下
import scipy.integrate as integrate
from scipy optimize import curve_fit
from numba import jit
def fitfunction():
...
@jit
def function(x):
# do some stuff
try:
fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=(0,0,0), maxfev=500)
for idx in some_list:
integrated = integrate.quad(lambda x: fitfunction(fit_param), lower, upper)
except:
fit_param=(0,0,0)
...现在,这会导致以下错误:
LoweringError:对象失败(对象模式后端)
我认为这是由于jit无法处理try except (如果我只将jit放在curve_fit和integrate.quad部件上并围绕自己的try except结构工作的话,它也不起作用)。
import scipy.integrate as integrate
from scipy optimize import curve_fit
from numba import jit
def fitfunction():
...
@jit
def integral(lower, upper):
return integrate.quad(lambda x: fitfunction(fit_param), lower, upper)
@jit
def fitting(x, y, pzero, max_fev)
return curve_fit(fitfunction, x, y, p0=pzero, maxfev=max_fev)
def function(x):
# do some stuff
try:
fit_param, fit_cov = fitting(x, y, (0,0,0), 500)
for idx in some_list:
integrated = integral(lower, upper)
except:
fit_param=(0,0,0)
...是否有一种方法可以将jit与scipy.integrate.quad和curve_fit一起使用,而无需手动删除枕代码中的所有try except结构?
它还能加速密码吗?
发布于 2019-03-24 17:46:18
Numba只不过是,而不是,它是一个通用的库,用来加速代码的编写。有一类问题可以用numba以更快的方式解决(特别是在数组上有循环、数字处理的情况下),但其他的问题要么(1)不支持,要么(2)只稍微快一点,甚至慢得多。
..。它会加速密码吗?
SciPy已经是一个高性能的库,所以在大多数情况下,我希望numba执行得更差(或者很少:稍微好一点)。您可以做一些侧写来了解瓶颈是否真的存在于您jit泰德的代码中,然后您可以得到一些改进。但我怀疑瓶颈将出现在SciPy的编译代码中,而且编译后的代码可能已经得到了很大的优化(所以确实不太可能找到一个只能与该代码“竞争”的实现)。
是否有一种方法可以与scipy.integrate.quad和curve_fit一起使用jit,而无需手动删除除jit代码中的结构之外的所有try?
正如您正确地假设的那样,numba现在根本不支持try和except。
2.6.1.语言 2.6.1.1.构式 Numba尽力支持尽可能多的Python语言,但是Numba编译的函数中没有一些语言特性。目前不支持以下Python语言特性: ..。
try ..except,try .finally)所以这里的答案是No。
发布于 2021-07-26 06:21:27
现在,try和except和numba一起工作。然而,农巴和枕叶仍然不相容。是的,Scipy调用编译了C和Fortran,但它是以numba无法处理的方式进行的。
幸运的是,还有其他的选择,可以很好地应用于农巴!下面,我使用NumbaQuadpack和NumbaMinpack进行一些曲线拟合和集成,类似于您的示例代码。免责声明:我把这些包裹放在一起。在下面,我也给出了一个等价的执行方案。
的实现速度是 (NumbaQuadpack和NumbaMinpack)的18倍。
使用Scipy替代品(0.23ms)
from NumbaQuadpack import quadpack_sig, dqags
from NumbaMinpack import minpack_sig, lmdif
import numpy as np
import numba as nb
import timeit
np.random.seed(0)
x = np.linspace(0,2*np.pi,100)
y = np.sin(x)+ np.random.rand(100)
@nb.jit
def fitfunction(x, A, B):
return A*np.sin(B*x)
@nb.cfunc(minpack_sig)
def fitfunction_optimize(u_, fvec, args_):
u = nb.carray(u_,(2,))
args = nb.carray(args_,(200,))
A, B = u
x = args[:100]
y = args[100:]
for i in range(100):
fvec[i] = fitfunction(x[i], A, B) - y[i]
optimize_ptr = fitfunction_optimize.address
@nb.cfunc(quadpack_sig)
def fitfunction_integrate(x, data):
A = data[0]
B = data[1]
return fitfunction(x, A, B)
integrate_ptr = fitfunction_integrate.address
@nb.njit
def fast_function():
try:
neqs = 100
u_init = np.array([2.0,.8],np.float64)
args = np.append(x,y)
fitparam, fvec, success, info = lmdif(optimize_ptr , u_init, neqs, args)
if not success:
raise Exception
lower = 0.0
uppers = np.linspace(np.pi,np.pi*2.0,200)
solutions = np.empty(len(uppers))
for i in range(len(uppers)):
solutions[i], abserr, success = dqags(integrate_ptr, lower, uppers[i], data = fitparam)
if not success:
raise Exception
except:
print('doing something else')
fast_function()
iters = 1000
t_nb = timeit.Timer(fast_function).timeit(number=iters)/iters
print(t_nb)使用Scipy (4.4毫秒)
import scipy.integrate as integrate
from scipy.optimize import curve_fit
import numpy as np
import numba as nb
import timeit
np.random.seed(0)
x = np.linspace(0,2*np.pi,100)
y = np.sin(x)+ np.random.rand(100)
@nb.jit
def fitfunction(x, A, B):
return A*np.sin(B*x)
def function():
try:
p0 = (2.0,.8)
fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=p0, maxfev=500)
lower = 0.0
uppers = np.linspace(np.pi,np.pi*2.0,200)
solutions = np.empty(len(uppers))
for i in range(len(uppers)):
solutions[i], abserr = integrate.quad(fitfunction, lower, uppers[i], args = tuple(fit_param))
except:
print('do something else')
function()
iters = 1000
t_sp = timeit.Timer(function).timeit(number=iters)/iters
print(t_sp)https://stackoverflow.com/questions/55317665
复制相似问题