首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >向矢量化大NumPy乘法

向矢量化大NumPy乘法
EN

Stack Overflow用户
提问于 2015-10-25 20:19:12
回答 2查看 444关注 0票数 5

我感兴趣的是计算一个大型NumPy数组。我有一个大数组A,它包含一串数字。我要计算这些数字的不同组合的和。这些数据的结构如下:

代码语言:javascript
复制
A = np.random.uniform(0,1, (3743, 1388, 3))
Combinations = np.random.randint(0,3, (306,3))
Final_Product = np.array([  np.sum( A*cb, axis=2)  for cb in Combinations])

我的问题是,是否有一种更优雅和内存更有效的方法来计算这一点?当涉及到三维数组时,我发现使用np.dot()是件令人沮丧的事情.

如果有帮助,Final_Product的理想形状应该是(3743,306,1388)。目前Final_Product是形状(306,3743,1388),所以我可以只是重塑到那里。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2015-10-25 20:24:58

np.dot()不会给出所需的输出,除非涉及可能包含reshaping的额外步骤。这里有一种vectorized方法,使用np.einsum来完成它,而不需要任何额外的内存开销-

代码语言:javascript
复制
Final_Product = np.einsum('ijk,lk->lij',A,Combinations)

为了完整起见,下面是前面讨论过的np.dotreshaping -

代码语言:javascript
复制
M,N,R = A.shape
Final_Product = A.reshape(-1,R).dot(Combinations.T).T.reshape(-1,M,N)

运行时测试和验证输出-

代码语言:javascript
复制
In [138]: # Inputs ( smaller version of those listed in question )
     ...: A = np.random.uniform(0,1, (374, 138, 3))
     ...: Combinations = np.random.randint(0,3, (30,3))
     ...: 

In [139]: %timeit np.array([  np.sum( A*cb, axis=2)  for cb in Combinations])
1 loops, best of 3: 324 ms per loop

In [140]: %timeit np.einsum('ijk,lk->lij',A,Combinations)
10 loops, best of 3: 32 ms per loop

In [141]: M,N,R = A.shape

In [142]: %timeit A.reshape(-1,R).dot(Combinations.T).T.reshape(-1,M,N)
100 loops, best of 3: 15.6 ms per loop

In [143]: Final_Product =np.array([np.sum( A*cb, axis=2)  for cb in Combinations])
     ...: Final_Product2 = np.einsum('ijk,lk->lij',A,Combinations)
     ...: M,N,R = A.shape
     ...: Final_Product3 = A.reshape(-1,R).dot(Combinations.T).T.reshape(-1,M,N)
     ...: 

In [144]: print np.allclose(Final_Product,Final_Product2)
True

In [145]: print np.allclose(Final_Product,Final_Product3)
True
票数 5
EN

Stack Overflow用户

发布于 2015-10-25 20:39:06

而不是dot,您可以使用tensordot。您的当前方法相当于:

代码语言:javascript
复制
np.tensordot(A, Combinations, [2, 1]).transpose(2, 0, 1)

请注意末尾的transpose以正确的顺序放置轴。

dot类似,tensordot函数可以调用快速的BLAS/LAPACK库(如果已经安装了它们),因此对于大型数组应该执行得很好。

票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/33334630

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档