首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >NumPy中的自定义非线性矩阵乘法

NumPy中的自定义非线性矩阵乘法
EN

Stack Overflow用户
提问于 2022-06-22 16:39:01
回答 1查看 66关注 0票数 0

假设我必须得到矩阵UW

代码语言:javascript
复制
U = np.arange(6*2).reshape((6,2))
W = np.arange(5*2).reshape((5,2))

对于一个标准的线性乘法,我可以:

代码语言:javascript
复制
U @ W.T
代码语言:javascript
复制
array([[  1,   3,   5,   7,   9],
       [  3,  13,  23,  33,  43],
       [  5,  23,  41,  59,  77],
       [  7,  33,  59,  85, 111],
       [  9,  43,  77, 111, 145],
       [ 11,  53,  95, 137, 179]])

但是,我也可以(技术上)定义一个线性乘法函数,在一个for-循环中按列和和:

代码语言:javascript
复制
def mult(U, W, i):
  return U[:, [i]] @ W.T[[i],:]

sum([mult(U, W, i) for i in range(2)]) #1
代码语言:javascript
复制
array([[  1,   3,   5,   7,   9],
       [  3,  13,  23,  33,  43],
       [  5,  23,  41,  59,  77],
       [  7,  33,  59,  85, 111],
       [  9,  43,  77, 111, 145],
       [ 11,  53,  95, 137, 179]])

现在假设mult()不再是线性的,它是非线性的,定制的,例如:

代码语言:javascript
复制
def mult(U, W, i):
  return (U[:, [i]] @ W.T[[i],:]) * np.cos(U[:, [i]] @ W.T[[i],:])

sum([mult(U, W, i) for i in range(2)]) #2

您可以验证这与(U @ W.T) * np.cos(U @ W.T)不完全相同。但是我想知道是否有一种更紧凑的方式来编写#2,就像如果#1是线性的话,有一种更紧凑的方式来编写mult()。效率很好,但我不是在处理巨大的矩阵。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-06-22 17:58:29

@,和np.dot一样,是一种矩阵乘法,涉及到我们通常所说的乘积之和.这是一个基本的线性代数操作,np.matmul使用高效的编译库来完成这个操作(在可能的情况下)。

您的sum([mult(...))正在这样做--获取行/列产品并对它们进行求和。编译后的代码可能使用更高效的方法,这些方法在迭代的cFortran中很好地工作。

您的mult函数可以使用广播元素方向乘法。对于一个i

代码语言:javascript
复制
In [43]: i=1;U[:, [i]] @ W.T[[i],:]     # (6,1) @ (1,5) => (6,5)
Out[43]: 
array([[ 1,  3,  5,  7,  9],
       [ 3,  9, 15, 21, 27],
       [ 5, 15, 25, 35, 45],
       [ 7, 21, 35, 49, 63],
       [ 9, 27, 45, 63, 81],
       [11, 33, 55, 77, 99]])

In [44]: i=1;U[:, [i]] * W.T[[i],:]
Out[44]: 
array([[ 1,  3,  5,  7,  9],
       [ 3,  9, 15, 21, 27],
       [ 5, 15, 25, 35, 45],
       [ 7, 21, 35, 49, 63],
       [ 9, 27, 45, 63, 81],
       [11, 33, 55, 77, 99]])

如果没有清单理解,这可以写成:

代码语言:javascript
复制
In [46]: (U[:,None,:]*W[None,:,:]).shape
Out[46]: (6, 5, 2)

In [47]: (U[:,None,:]*W[None,:,:]).sum(axis=2)
Out[47]: 
array([[  1,   3,   5,   7,   9],
       [  3,  13,  23,  33,  43],
       [  5,  23,  41,  59,  77],
       [  7,  33,  59,  85, 111],
       [  9,  43,  77, 111, 145],
       [ 11,  53,  95, 137, 179]])

至于您使用`np.cos的版本:

代码语言:javascript
复制
In [48]: def mult(U, W, i):
    ...:   return (U[:, [i]] @ W.T[[i],:]) * np.cos(U[:, [i]] @ W.T[[i],:])
    ...: sum([mult(U, W, i) for i in range(2)]) #2
Out[48]: 
array([[ 5.40302306e-01, -2.96997749e+00,  1.41831093e+00,
         5.27731578e+00, -8.20017236e+00],
       [-2.96997749e+00, -1.08147468e+01, -1.25593190e+01,
        -1.37606696e+00, -2.32102995e+01],
       [ 1.41831093e+00, -1.25593190e+01,  9.45751861e+00,
        -2.14489310e+01,  5.03346370e+01],
       [ 5.27731578e+00, -1.37606696e+00, -2.14489310e+01,
         1.01223418e+01,  3.13845563e+01],
       [-8.20017236e+00, -2.32102995e+01,  5.03346370e+01,
         3.13845563e+01,  8.79904273e+01],
       [ 4.86826779e-02,  7.72350858e+00, -2.54605509e+01,
        -5.95298563e+01, -4.88871235e+00]])

我可以使用相同的外部/和格式:

代码语言:javascript
复制
In [49]: (U[:,None,:]*W[None,:,:]*np.cos(U[:,None,:]*W[None,:,:])).sum(axis=2)
Out[49]: 
array([[ 5.40302306e-01, -2.96997749e+00,  1.41831093e+00,
         5.27731578e+00, -8.20017236e+00],
       [-2.96997749e+00, -1.08147468e+01, -1.25593190e+01,
        -1.37606696e+00, -2.32102995e+01],
       [ 1.41831093e+00, -1.25593190e+01,  9.45751861e+00,
        -2.14489310e+01,  5.03346370e+01],
       [ 5.27731578e+00, -1.37606696e+00, -2.14489310e+01,
         1.01223418e+01,  3.13845563e+01],
       [-8.20017236e+00, -2.32102995e+01,  5.03346370e+01,
         3.13845563e+01,  8.79904273e+01],
       [ 4.86826779e-02,  7.72350858e+00, -2.54605509e+01,
        -5.95298563e+01, -4.88871235e+00]])

由于外部产品被使用了两次,我们可以使用一个临时变量:

代码语言:javascript
复制
In [51]: temp=U[:,None,:]*W[None,:,:]; 
         (temp*np.cos(temp)).sum(axis=2)
Out[51]: 
array([[ 5.40302306e-01, -2.96997749e+00,  1.41831093e+00,
         5.27731578e+00, -8.20017236e+00],
       [-2.96997749e+00, -1.08147468e+01, -1.25593190e+01,
        -1.37606696e+00, -2.32102995e+01],
       [ 1.41831093e+00, -1.25593190e+01,  9.45751861e+00,
        -2.14489310e+01,  5.03346370e+01],
       [ 5.27731578e+00, -1.37606696e+00, -2.14489310e+01,
         1.01223418e+01,  3.13845563e+01],
       [-8.20017236e+00, -2.32102995e+01,  5.03346370e+01,
         3.13845563e+01,  8.79904273e+01],
       [ 4.86826779e-02,  7.72350858e+00, -2.54605509e+01,
        -5.95298563e+01, -4.88871235e+00]])

你不能简单地交换乘法和和步骤,这是一个基本代数问题。

要获得

代码语言:javascript
复制
a1*b1 + a2*b2   

从…

代码语言:javascript
复制
(a1+a2)*(b1+b2) => a1*b1 + a1*b2 + a2*b1 + a2*b2

a1*b2 + a2*b1项必须求和为零,如复数的大小:

代码语言:javascript
复制
In [53]: (1+4j)*(1-4j)
Out[53]: (17+0j)    # (1+16)

一般情况下,产品之和不能转化为总和的乘积。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72719133

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档