首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么NumPy矩阵在一个方向上工作而不是在转置方向上工作?

为什么NumPy矩阵在一个方向上工作而不是在转置方向上工作?
EN

Stack Overflow用户
提问于 2019-10-01 20:28:04
回答 1查看 171关注 0票数 0

考虑两个数组之间的矩阵乘积:

代码语言:javascript
复制
import numpy as np
A = np.random.rand(2,10,10)                                             
B = np.random.rand(2,2)                                                 
C = A.T @ B

...goes很好。我认为以上是一个1乘2乘以2乘2的矢量矩阵产品,在A的10乘10,2,2和3维上广播,结果的检验C证实了这个直觉,np.allclose(C[i,j], A.T[i,j] @ B)代表了所有的ij

现在,从数学上讲,我应该能够计算C.T以及:B.T @ A,但是:

代码语言:javascript
复制
B.T @ A                                                                
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-32-ffdbb14ca160> in <module>
----> 1 B.T @ A

ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 10 is different from 2)

所以就广播而言,10乘10乘2张量和2乘2矩阵与矩阵乘积是相容的,但是2乘2矩阵和2乘10乘10张量不是吗?

额外信息:,我希望能够计算出“二次乘积”A.T @ B @ A,它真的让我很恼火,我不得不在一个维度上手动地“广播”。我觉得这样做应该是可能的。我对Python和NumPy相当有经验,但我很少超越二维数组。

我在这里错过了什么?关于转置对NumPy中的张量的操作方式,我是否还不明白呢?

EN

回答 1

Stack Overflow用户

发布于 2019-10-01 21:39:32

代码语言:javascript
复制
In [194]: A = np.random.rand(2,10,10)                                           
     ...:    
     ...: B = np.random.rand(2,2)                                               
In [196]: A.T.shape                                                             
Out[196]: (10, 10, 2)

In [197]: C = A.T @ B                                                           
In [198]: C.shape                                                               
Out[198]: (10, 10, 2)

einsum的等效值是:

代码语言:javascript
复制
In [199]: np.allclose(np.einsum('ijk,kl->ijl',A.T,B),C)                         
Out[199]: True

或在索引中加入转置:

代码语言:javascript
复制
In [200]: np.allclose(np.einsum('kji,kl->ijl',A,B),C)                           
Out[200]: True

请注意,k是相加的维度。jldot的其他维度。i是一种批处理维数。

或者就像你解释的那样np.einsum('k,kl->l', A.T[i,j], B)

要获得C.Teinsum结果索引应该是ljilk,jki->lji

代码语言:javascript
复制
In [201]: np.allclose(np.einsum('lk,jki->lji', B.T, A.transpose(1,0,2)), C.T)      
Out[201]: True

In [226]: np.allclose(np.einsum('ij,jkl->ikl', B.T, A), C.T)                       
Out[226]: True

@匹配201需要进一步转置:

代码语言:javascript
复制
In [225]: np.allclose((B.T@(A.transpose(1,0,2))).transpose(1,0,2), C.T)          
Out[225]: True

使用einsum时,可以将轴按任何顺序放置,但是对于matmul,顺序是固定的(batch, i, k)@(batch, k, l) -> (batch, i, l) ( batch维度可以广播)。

如果A具有形状(2,10,9)和B (2,3),而C生成(9,10,3),则示例可能更容易一些。

代码语言:javascript
复制
In [229]: A = np.random.rand(2,10,9); B = np.random.rand(2,3)                   
In [230]: C = A.T @ B                                                           
In [231]: C.shape                                                               
Out[231]: (9, 10, 3)
In [232]: C.T.shape                                                             
Out[232]: (3, 10, 9)

In [234]: ((B.T) @ (A.transpose(1,0,2))).shape                                    
Out[234]: (10, 3, 9)
In [235]: ((B.T) @ (A.transpose(1,0,2))).transpose(1,0,2).shape                   
Out[235]: (3, 10, 9)
In [236]: np.allclose(((B.T) @ (A.transpose(1,0,2))).transpose(1,0,2), C.T)        
Out[236]: True
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58191834

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档