我正在做以下计算:
y = np.mat(np.log(datay))
x = np.mat([datax**2, datax, np.ones(len(datax))]).T
popt = (x.T * x).I * x.T * y.Tdatax和datay是正常的一维np.arrays,例如:
datay = np.array([1,4,9,16])
datax = np.array([1,2,3,4])计算效果很好。但是我想加快速度,所以我试着把它放进numba:(我是numba...but的新手,想试一试)
@jit(nopython=True)
def calc(datax, datay):
y = np.mat(np.log(datay))
x = np.mat([datax**2, datax, np.ones(len(datax))]).T
return (x.T * x).I * x.T * y.T但这是行不通的。我得到以下错误
在nopython (nopython前端)失败 模块numpy的未知属性“矩阵”
那我怎么能让它起作用呢?
第二件事是:正如您可能注意到的,我正在计算二阶多项式的参数。我需要尽快做到这一点,因为我需要经常这样做。所以现在我只想把所有的
result = np.zeros(len(datay), 3)
datax = np.array([1,2,3,4)]
x = np.matrix([datax**2, datax, np.ones(len(datax))]).T
for i, data in enumerate(datay):
data = np.array(data-baseline)
if (any(i <= 0 for i in data)): continue
try:
y = np.matrix(np.log(data))
result[i] = ((x.T * x).I * x.T * y.T).A1我怎样才能加快速度:只要把所有的内容都放在一个numba函数中,并希望编译就行了?或者还有其他聪明的方法?numba有并行化的工具,对吗?它们能适用于我的情况吗?
谢谢你抽出时间:)
发布于 2018-04-25 13:58:56
numba支持的所有numpy特性都在文档中的以下页面中列出:
http://numba.pydata.org/numba-doc/latest/reference/numpysupported.html
未列出对numpy矩阵对象的支持,因此目前无法在numba jitted程序中使用它们。如果您所做的工作很容易被向量化,并且不涉及创建许多中间数组对象,那么从numba获得的速度可能是有限的。您可以尝试从matrix切换到array数据结构,因为后者是受支持的。矩阵对象只是有一些稍微不同的行为,所以应该很容易转换代码。
https://stackoverflow.com/questions/50023703
复制相似问题