首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >加速numpy计算

加速numpy计算
EN

Stack Overflow用户
提问于 2022-05-27 17:11:46
回答 1查看 151关注 0票数 2

我正在编写一个Python函数来整合高维矩阵空间上的向量场。A,shape (n, m),是一个矩阵,其时间导数在其每个分量A[i, j]中是线性的。我们可以将导数的所有系数收集到一个四维阵列C中,使得C[i, j, k, l]A[i, j]导数中的A[k, l]系数。在这种情况下,A的导数由dA[i, j] == (C[i, j] * A).sum()给出。因此,计算是正确的。

代码语言:javascript
复制
dA = np.array([[ (Cij * A).sum() for Cij in Ci ] for Ci in C ])

幸运的是,C可以表示为一个sparse.COO对象,因此上面只需要O(nm)乘法。但这两个for循环仍然很慢。多亏了一条有用的评论,我把这个改进为

代码语言:javascript
复制
dA = (C * A).sum(axis=3).sum(axis=2)

利用广播进行显著的加速。有人能开快点吗?

EN

回答 1

Stack Overflow用户

发布于 2022-05-27 20:06:07

您可以使用np.einsum来加速这一点,因为您不需要做任何中间计算。或者至少您可以执行(C * A).sum(axis=(2,3))来删除一个中间步骤。

代码语言:javascript
复制
import numpy as np
A = np.full((12,12), 2)
C = np.full((12,12,3,2), 1).T
dA = (C * A).sum(axis=3).sum(axis=2)
print(np.einsum('abkl,ijkl->ij', A[None, None], C) == dA)
print((C * A).sum(axis=(2,3)) == dA)

输出:

代码语言:javascript
复制
[[ True  True  True]
 [ True  True  True]]
[[ True  True  True]
 [ True  True  True]]

老实说,我并不完全理解你的数学问题,我也不太擅长einsum。也就是说,您应该反复检查算法和测试用例是否正确:)

编辑:添加.sum(axis=(2,3))方法

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72408913

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档