首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >理解PyTorch einsum

理解PyTorch einsum
EN

Stack Overflow用户
提问于 2019-04-28 21:23:57
回答 1查看 24K关注 0票数 29

我熟悉合额在NumPy中的工作方式。PyTorch:torch.einsum()也提供了类似的功能。在功能或性能方面,有什么相似之处和不同之处?PyTorch文档中提供的信息非常少,没有提供任何有关这方面的见解。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-04-28 21:37:42

由于einsum的描述在torch文档中比较少,所以我决定写这篇文章来记录、比较和对比torch.einsum()numpy.einsum()的行为。

差异:

  • NumPy允许小写字母和大写字母[a-zA-Z]用于“下标字符串”,而PyTorch只允许小写字母[a-z]
  • NumPy接受nd-数组、普通Python列表(或元组)、列表列表(或元组的元组、元组的列表、列表的元组)甚至PyTorch张量作为操作数(即输入)。这是因为操作数只能是array_like,而不是严格的NumPy和数组。相反,PyTorch期望操作数(即输入)严格地是PyTorch张量。如果您传递普通Python /tuple(或其组合)或NumPy nd-数组,它将抛出一个NumPy。
  • NumPy除了支持nd-arrays之外,还支持很多关键字参数(例如optimize),而PyTorch还没有提供这样的灵活性。

以下是PyTorch和NumPy中一些示例的实现:

代码语言:javascript
复制
# input tensors to work with

In [16]: vec
Out[16]: tensor([0, 1, 2, 3])

In [17]: aten
Out[17]: 
tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])

In [18]: bten
Out[18]: 
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])

1)矩阵乘法

代码语言:javascript
复制
  PyTorch: `torch.matmul(aten, bten)` ; `aten.mm(bten)`
代码语言:javascript
复制
  NumPy : `np.einsum("ij, jk -> ik", arr1, arr2)` 
代码语言:javascript
复制
In [19]: torch.einsum('ij, jk -> ik', aten, bten)
Out[19]: 
tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])

2)沿主对角线提取元素。

代码语言:javascript
复制
PyTorch: `torch.diag(aten)`
代码语言:javascript
复制
NumPy : `np.einsum("ii -> i", arr)` 
代码语言:javascript
复制
In [28]: torch.einsum('ii -> i', aten)
Out[28]: tensor([11, 22, 33, 44])

3) Hadamard积(即两个张量的元素乘积)

代码语言:javascript
复制
PyTorch: `aten * bten`
代码语言:javascript
复制
NumPy : `np.einsum("ij, ij -> ij", arr1, arr2)` 
代码语言:javascript
复制
In [34]: torch.einsum('ij, ij -> ij', aten, bten)
Out[34]: 
tensor([[ 11,  12,  13,  14],
        [ 42,  44,  46,  48],
        [ 93,  96,  99, 102],
        [164, 168, 172, 176]])

4)元素向平方

代码语言:javascript
复制
PyTorch: `aten ** 2`
代码语言:javascript
复制
NumPy : `np.einsum("ij, ij -> ij", arr, arr)` 
代码语言:javascript
复制
In [37]: torch.einsum('ij, ij -> ij', aten, aten)
Out[37]: 
tensor([[ 121,  144,  169,  196],
        [ 441,  484,  529,  576],
        [ 961, 1024, 1089, 1156],
        [1681, 1764, 1849, 1936]])

General:元素级nth电源可以通过重复下标字符串和张量n时间来实现。例如,就计算单元而言,张量的第4次方可以使用以下方法完成:

代码语言:javascript
复制
# NumPy: np.einsum('ij, ij, ij, ij -> ij', arr, arr, arr, arr)
In [38]: torch.einsum('ij, ij, ij, ij -> ij', aten, aten, aten, aten)
Out[38]: 
tensor([[  14641,   20736,   28561,   38416],
        [ 194481,  234256,  279841,  331776],
        [ 923521, 1048576, 1185921, 1336336],
        [2825761, 3111696, 3418801, 3748096]])

5)迹(即主对角线元素之和)

代码语言:javascript
复制
PyTorch: `torch.trace(aten)`
代码语言:javascript
复制
NumPy einsum: `np.einsum("ii -> ", arr)` 
代码语言:javascript
复制
In [44]: torch.einsum('ii -> ', aten)
Out[44]: tensor(110)

6)矩阵转置

代码语言:javascript
复制
PyTorch: `torch.transpose(aten, 1, 0)`
代码语言:javascript
复制
NumPy einsum: `np.einsum("ij -> ji", arr)` 
代码语言:javascript
复制
In [58]: torch.einsum('ij -> ji', aten)
Out[58]: 
tensor([[11, 21, 31, 41],
        [12, 22, 32, 42],
        [13, 23, 33, 43],
        [14, 24, 34, 44]])

7)外积(向量)

代码语言:javascript
复制
PyTorch: `torch.ger(vec, vec)`
代码语言:javascript
复制
NumPy einsum: `np.einsum("i, j -> ij", vec, vec)` 
代码语言:javascript
复制
In [73]: torch.einsum('i, j -> ij', vec, vec)
Out[73]: 
tensor([[0, 0, 0, 0],
        [0, 1, 2, 3],
        [0, 2, 4, 6],
        [0, 3, 6, 9]])

8) (向量的内积) PyTorch:torch.dot(vec1, vec2)

代码语言:javascript
复制
NumPy einsum: `np.einsum("i, i -> ", vec1, vec2)` 
代码语言:javascript
复制
In [76]: torch.einsum('i, i -> ', vec, vec)
Out[76]: tensor(14)

9)沿轴0的求和

代码语言:javascript
复制
PyTorch: `torch.sum(aten, 0)`
代码语言:javascript
复制
NumPy einsum: `np.einsum("ij -> j", arr)` 
代码语言:javascript
复制
In [85]: torch.einsum('ij -> j', aten)
Out[85]: tensor([104, 108, 112, 116])

10)沿轴1求和

代码语言:javascript
复制
 PyTorch: `torch.sum(aten, 1)`
代码语言:javascript
复制
 NumPy einsum: `np.einsum("ij -> i", arr)` 
代码语言:javascript
复制
In [86]: torch.einsum('ij -> i', aten)
Out[86]: tensor([ 50,  90, 130, 170])

11)批量矩阵乘法

代码语言:javascript
复制
 PyTorch: `torch.bmm(batch_tensor_1, batch_tensor_2)`
代码语言:javascript
复制
 NumPy  : `np.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)` 
代码语言:javascript
复制
# input batch tensors to work with
In [13]: batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
In [14]: batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4) 

In [15]: torch.bmm(batch_tensor_1, batch_tensor_2)  
Out[15]: 
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

# sanity check with the shapes
In [16]: torch.bmm(batch_tensor_1, batch_tensor_2).shape 
Out[16]: torch.Size([2, 4, 4])

# batch matrix multiply using einsum
In [17]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)
Out[17]: 
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

# sanity check with the shapes
In [18]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2).shape

12)沿轴2的和

代码语言:javascript
复制
 PyTorch: `torch.sum(batch_ten, 2)`
代码语言:javascript
复制
 NumPy einsum: `np.einsum("ijk -> ij", arr3D)` 
代码语言:javascript
复制
In [99]: torch.einsum("ijk -> ij", batch_ten)
Out[99]: 
tensor([[ 50,  90, 130, 170],
        [  4,   8,  12,  16]])

13)将nD张量中的所有元素相加

代码语言:javascript
复制
 PyTorch: `torch.sum(batch_ten)`
代码语言:javascript
复制
 NumPy einsum: `np.einsum("ijk -> ", arr3D)` 
代码语言:javascript
复制
In [101]: torch.einsum("ijk -> ", batch_ten)
Out[101]: tensor(480)

14)多轴求和(即边际化)

代码语言:javascript
复制
 PyTorch: `torch.sum(arr, dim=(dim0, dim1, dim2, dim3, dim4, dim6, dim7))`
代码语言:javascript
复制
 NumPy: `np.einsum("ijklmnop -> n", nDarr)` 
代码语言:javascript
复制
# 8D tensor
In [103]: nDten = torch.randn((3,5,4,6,8,2,7,9))
In [104]: nDten.shape
Out[104]: torch.Size([3, 5, 4, 6, 8, 2, 7, 9])

# marginalize out dimension 5 (i.e. "n" here)
In [111]: esum = torch.einsum("ijklmnop -> n", nDten)
In [112]: esum
Out[112]: tensor([  98.6921, -206.0575])

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [113]: tsum = torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))

In [115]: torch.allclose(tsum, esum)
Out[115]: True

15)双点产品/ Frobenius内积 (同:torch.sum(hadamard- Products )) cf.3)

代码语言:javascript
复制
 PyTorch: `torch.sum(aten * bten)`
代码语言:javascript
复制
 NumPy  : `np.einsum("ij, ij -> ", arr1, arr2)` 
代码语言:javascript
复制
In [120]: torch.einsum("ij, ij -> ", aten, bten)
Out[120]: tensor(1300)
票数 81
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55894693

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档