首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Numba njit编译器导致与普通Python代码相比计算不同的数字?

Numba njit编译器导致与普通Python代码相比计算不同的数字?
EN

Stack Overflow用户
提问于 2021-11-18 17:10:40
回答 1查看 71关注 0票数 1

我在Python中使用Numba的njit工具时遇到了一个问题。我注意到,该函数在使用@numba.njit运行和以普通Python代码运行时会产生不同的结果。特别是,在调试之后,我注意到在使用numpy执行矩阵求逆时,计算中会出现差异。请参阅下面的测试代码。矩阵A和向量b的值位于以下csv文件中,可以通过以下链接访问这些文件:A.csvb.csv

普通Python函数的结果是正确的。请帮我解决这个问题!我是否需要在numpy矩阵求逆函数周围使用Numba包装函数来解决似乎是一个数值问题?

好心的毕业生们,我期待着很快收到你们的回复:)

艾哈迈德

代码语言:javascript
复制
@numba.njit
def cal_Test_jit(A,b):
    c = np.linalg.inv(A)@b
    return c, np.linalg.inv(A)

def cal_Test(A,b):
    c = np.linalg.inv(A)@b
    return c, np.linalg.inv(A)

A = np.loadtxt(open("A.csv", "rb"), delimiter=",")
b = np.loadtxt(open("b.csv", "rb"), delimiter=",")

c_jit, Ai_jit = cal_Test_jit(A,b)
c, Ai = cal_Test(A,b)
err_c = abs(c-c_jit)
err_A = abs(Ai_jit-Ai)

# ploting the error in the parameters
plt.figure()
plt.plot(err_c)

# only ploting the error in first three columns of A
fig, ax = plt.subplots(1,3)
ax[0].plot(err_A[:,0])
ax[1].plot(err_A[:,1])
ax[2].plot(err_A[:,2])
EN

回答 1

Stack Overflow用户

发布于 2021-11-18 21:46:37

在您的问题中使用numba的一种方法是添加:

代码语言:javascript
复制
@numba.jit(forceobj=True)

它将获得真实的结果,但执行时间比较为using njit (different results) > this method (exact results) == plain Python,例如使用colab TPU:

代码语言:javascript
复制
1000 loops, best of 5: 545 µs per loop     # using njit
1000 loops, best of 5: 505 µs per loop     # this method
1000 loops, best of 5: 500 µs per loop     # plain Python

但正如之前在另一个SO question上推荐的那样,Numba对于优化纯Python的子集非常有用,尤其是循环,接近优化的C代码@BatWannaBe的性能,而且几乎总是有比求逆矩阵更好的方法,例如np.linalg.solve @Humer512;这在另一个SO question @ali_m上有很好的解释。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70023988

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档