arr1 = np.arange(8).reshape(4, 2)
arr2 = np.arange(4, 12).reshape(2, 4)
ans=np.tensordot(arr1,arr2,axes=([1],[0]))
ans2=np.tensordot(arr1,arr2,axes=([0],[1]))
ans3 = np.tensordot(arr1,arr2, axes=([1,0],[0,1]))我正在尝试理解这个tensordot函数是如何工作的。我知道它返回tensordot乘积。
但是轴的部分对我来说有点难以理解。我观察到的是
对于ans,它就像数组arr1中的列数和arr2中的行数构成了最终的矩阵。
对于ans2,arr2中的列数和arr1中的行数正好相反
我不理解axes=(1,0,0,1)。让我知道我对ans和ans2的理解是否正确
发布于 2021-01-16 00:59:18
您忘了显示数组:
In [87]: arr1
Out[87]:
array([[0, 1],
[2, 3],
[4, 5],
[6, 7]])
In [88]: arr2
Out[88]:
array([[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
In [89]: ans
Out[89]:
array([[ 8, 9, 10, 11],
[ 32, 37, 42, 47],
[ 56, 65, 74, 83],
[ 80, 93, 106, 119]])
In [90]: ans2
Out[90]:
array([[ 76, 124],
[ 98, 162]])
In [91]: ans3
Out[91]: array(238)ans只是常规的点阵乘积:
In [92]: np.dot(arr1,arr2)
Out[92]:
array([[ 8, 9, 10, 11],
[ 32, 37, 42, 47],
[ 56, 65, 74, 83],
[ 80, 93, 106, 119]])dot乘积和是在arr1的([1],[0])轴1和arr2的轴0上执行的(传统的是跨列、向下行)。用2d 'sum ...‘短语可能会让人感到困惑。在处理1或3d数组时,这一点更清晰。这里将匹配大小为2的维度相加,留下(4,4)。
ans2反转它们,对4求和,产生一个(2,2):
In [94]: np.dot(arr2,arr1)
Out[94]:
array([[ 76, 98],
[124, 162]])tensordot刚刚调换了两个数组,并执行了常规的dot
In [95]: np.dot(arr1.T,arr2.T)
Out[95]:
array([[ 76, 124],
[ 98, 162]])ans3使用转置和整形(ravel),在两个轴上求和:
In [98]: np.dot(arr1.ravel(),arr2.T.ravel())
Out[98]: 238通常,tensordot混合使用转置和重塑来将问题简化为2d np.dot问题。然后,它可能会重塑和转置结果。
我发现einsum的维度控制更加清晰:
In [99]: np.einsum('ij,jk->ik',arr1,arr2)
Out[99]:
array([[ 8, 9, 10, 11],
[ 32, 37, 42, 47],
[ 56, 65, 74, 83],
[ 80, 93, 106, 119]])
In [100]: np.einsum('ji,kj->ik',arr1,arr2)
Out[100]:
array([[ 76, 124],
[ 98, 162]])
In [101]: np.einsum('ij,ji',arr1,arr2)
Out[101]: 238随着einsum和matmul/@的发展,tensordot变得越来越不必要。它更难理解,并且没有任何速度或灵活性优势。不要担心理解它。
ans3是其他两个ans的迹线(对角线之和):
In [103]: np.trace(ans)
Out[103]: 238
In [104]: np.trace(ans2)
Out[104]: 238发布于 2021-01-15 19:27:29
根据我对tensordot文档的理解,您在ans、ans2和ans3中提供了一个axis列表(ans和ans2在列表中只有一个元素)。然后,该列表指定要对哪些轴求和。您对ans和ans2的假设是正确的,在ans中,您的第一个元素是arr1的0轴(arr1中的行)和arr2中的1轴(arr2中的列)。我不能完全确定对ans3有什么期望,但我可能会尝试自己运行一些示例,并查看一下。我希望这能让你更好地理解
链接:https://numpy.org/doc/stable/reference/generated/numpy.tensordot.html
https://stackoverflow.com/questions/65735024
复制相似问题