首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Numpy文档中的Tensordot解释

Numpy文档中的Tensordot解释
EN

Stack Overflow用户
提问于 2021-10-31 20:21:08
回答 1查看 32关注 0票数 0

我不明白tensordot是如何工作的,我正在阅读官方文档,但我完全不明白那里发生了什么。

代码语言:javascript
复制
a = np.arange(60.).reshape(3,4,5)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a,b, axes=([1,0],[0,1]))
c.shape
(5, 2)

为什么形状是(5,2)?到底发生了什么?

我也读过this article,但答案把我搞糊涂了。

代码语言:javascript
复制
 In [7]: A = np.random.randint(2, size=(2, 6, 5))
   ...:  B = np.random.randint(2, size=(3, 2, 4))
   ...: 
代码语言:javascript
复制
In [9]: np.tensordot(A, B, axes=((0),(1))).shape
Out[9]: (6, 5, 3, 4)

A : (2, 6, 5) -> reduction of axis=0
B : (3, 2, 4) -> reduction of axis=1

Output : `(2, 6, 5)`, `(3, 2, 4)` ===(2 gone)==> `(6,5)` + `(3,4)` => `(6,5,3,4)`

为什么形状是(6, 5, 3, 4)

EN

回答 1

Stack Overflow用户

发布于 2021-10-31 23:32:41

代码语言:javascript
复制
In [196]: a = np.arange(60.).reshape(3,4,5)
     ...: b = np.arange(24.).reshape(4,3,2)
     ...: c = np.tensordot(a,b, axes=([1,0],[0,1]))
In [197]: c
Out[197]: 
array([[4400., 4730.],
       [4532., 4874.],
       [4664., 5018.],
       [4796., 5162.],
       [4928., 5306.]])

我发现einsum的等价物更容易“阅读”:

代码语言:javascript
复制
In [198]: np.einsum('ijk,jil->kl',a,b)
Out[198]: 
array([[4400., 4730.],
       [4532., 4874.],
       [4664., 5018.],
       [4796., 5162.],
       [4928., 5306.]])

tensordot通过转置和重塑输入来将问题简化为简单的dot

代码语言:javascript
复制
In [204]: a1 = a.transpose(2,1,0).reshape(5,12)
In [205]: b1 = b.reshape(12,2)
In [206]: np.dot(a1,b1)        # or a1@b1
Out[206]: 
array([[4400., 4730.],
       [4532., 4874.],
       [4664., 5018.],
       [4796., 5162.],
       [4928., 5306.]])

tensordot可以对结果进行进一步的操作,但这里不需要这样做。

在我得到正确的a1/b1之前,我必须尝试几件事。例如,a.transpose(2,0,1).reshape(5,12)会生成正确的形状,但值不同。

还有另一个版本:

代码语言:javascript
复制
In [210]: (a.transpose(1,0,2)[:,:,:,None]*b[:,:,None,:]).sum((0,1))
Out[210]: 
array([[4400., 4730.],
       [4532., 4874.],
       [4664., 5018.],
       [4796., 5162.],
       [4928., 5306.]])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69790298

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档