首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何将kronecker产品沿数组尺寸映射?

如何将kronecker产品沿数组尺寸映射?
EN

Stack Overflow用户
提问于 2022-09-10 16:51:51
回答 1查看 66关注 0票数 2

给出了两个维数相同的张量A和B,(d>=2)[A_{1},...,A_{d-2},A_{d-1},A_{d}][A_{1},...,A_{d-2},B_{d-1},B_{d}] (第一维d-2维的形状相同)。

有没有一种方法可以计算最后两个维度上的克朗克积?my_kron(A,B)的形状应为[A_{1},...,A_{d-2},A_{d-1}*B_{d-1},A_{d}*B_{d}]。例如,对于d=3

代码语言:javascript
复制
A.shape=[2,3,3]
B.shape=[2,4,4]
C=my_kron(A,B)

C[0,...]应该是A[0,...]B[0,...]的kronecker产品,C[1,...]A[1,...]B[1,...]的kronecker产品。

对于d=2,这只是jnp.kron(或np.kron)函数所做的事情。

对于d=3,这可以通过jax.vmap来实现。jax.vmap(lambda x, y: jnp.kron(x[0, :], y[0, :]))(A, B)

但我无法找到一般(未知)维度的解决方案。有什么建议吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-09-10 17:20:26

numpy术语来说,我认为这就是您正在做的事情:

代码语言:javascript
复制
In [104]: A = np.arange(2*3*3).reshape(2,3,3)
In [105]: B = np.arange(2*4*4).reshape(2,4,4)

In [106]: C = np.array([np.kron(a,b) for a,b in zip(A,B)])
In [107]: C.shape
Out[107]: (2, 12, 12)

它将初始维度2作为一个batch来处理。一个明显的推广是对数组进行整形,将较高的维数降到1,例如reshape(-1,3,3)等,然后,将C重塑回所需的n维。

np.kron确实接受3d (甚至更高),但它在共享的二维上执行某种outer

代码语言:javascript
复制
In [108]: np.kron(A,B).shape
Out[108]: (4, 12, 12)

将这个四维可视化为(2,2),我可以拿diagonal,得到你的C

代码语言:javascript
复制
In [109]: np.allclose(np.kron(A,B)[[0,3]], C)
Out[109]: True

完整的kron做的计算比需要的要多,但速度仍然更快:

代码语言:javascript
复制
In [110]: timeit C = np.array([np.kron(a,b) for a,b in zip(A,B)])
108 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [111]: timeit np.kron(A,B)[[0,3]]
76.4 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

我相信可以更直接地进行计算,但要做到这一点,就需要更好地理解kron的工作原理。就像np.kron代码所暗示的那样,快速浏览一下outer(A,B)

代码语言:javascript
复制
In [114]: np.outer(A,B).shape
Out[114]: (18, 32)

它具有相同数量的元素,但随后它将reshapesconcatenates用于生成kron布局。

但根据直觉,我发现这相当于你想要的东西:

代码语言:javascript
复制
In [123]: D = A[:,:,None,:,None]*B[:,None,:,None,:]
In [124]: np.allclose(D.reshape(2,12,12),C)
Out[124]: True
In [125]: timeit np.reshape(A[:,:,None,:,None]*B[:,None,:,None,:],(2,12,12))
14.3 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

这很容易被推广到更多的领先维度。

代码语言:javascript
复制
def my_kron(A,B):
   D = A[...,:,None,:,None]*B[...,None,:,None,:]
   ds = D.shape
   newshape = (*ds[:-4],ds[-4]*ds[-3],ds[-2]*ds[-1])
   return D.reshape(newshape)

In [137]: my_kron(A.reshape(1,2,1,3,3),B.reshape(1,2,1,4,4)).shape
Out[137]: (1, 2, 1, 12, 12)
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73673599

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档