首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensordot的性能瓶颈

Tensordot的性能瓶颈
EN

Stack Overflow用户
提问于 2018-01-31 03:59:34
回答 1查看 887关注 0票数 1

当我试图理解numpy.tensordot()时,我尝试了文档中的示例,并确信通过不同的axes参数排列,我们可以得到完全相同的axes编辑结果。例如,以下两个轴的排列是等价的(即它们都产生相同的结果):

代码语言:javascript
复制
In [28]: a = np.arange(60.).reshape(3,4,5)
In [29]: b = np.arange(24.).reshape(4,3,2)

In [30]: perm1 = np.tensordot(a, b, axes=[(1, 0), (0, 1)])
In [31]: perm2 = np.tensordot(a, b, axes=[(0, 1), (1, 0)])

In [32]: np.all(perm1 == perm2)
Out[32]: True

然而,在测量性能时,我发现一个排列比另一个的快2倍,这让我很困惑。

代码语言:javascript
复制
# setting up input arrays
In [19]: a = np.arange(30*40*50).reshape(30,40,50)
In [20]: b = np.arange(40*30*20).reshape(40,30,20)

# contracting the first two axes from the input tensors
In [21]: %timeit np.tensordot(a, b, axes=[(0, 1), (1, 0)])
3.23 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# an equivalent way of contraction of the first two
# axes from the input tensors as in the above case
In [22]: %timeit np.tensordot(a, b, axes=[(1, 0), (0, 1)])
1.62 ms ± 16.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

那么,在后一种情况下,2x加速的原因是什么?这与NumPy ndarrays在内存中的内部结构有关吗?还是别的什么?提前感谢您的见解!

EN

回答 1

Stack Overflow用户

发布于 2018-01-31 07:58:34

不详细介绍,这两种计算将重新创建tensordot所采取的操作,并生成相同的perm值。

它们显示出同样的2倍的速度差异:

代码语言:javascript
复制
In [24]: timeit np.dot(a.transpose(2,0,1).reshape(50,-1), b.transpose(1,0,2).reshape(-1,20))
4.39 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [25]: timeit np.dot(a.transpose(2,1,0).reshape(50,-1), b.reshape(-1,20))
2.99 ms ± 97.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

我猜第二个更快,因为b.reshape(-1,20)不需要拷贝,而转置和第一个的整形是这样做的。

不同的重塑时机:

代码语言:javascript
复制
In [28]: timeit a.transpose(2,1,0).reshape(50,-1)
128 µs ± 978 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [29]: timeit a.transpose(2,0,1).reshape(50,-1)
1.04 µs ± 21.1 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [30]: timeit b.reshape(-1,20)
501 ns ± 14.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [31]: timeit b.transpose(1,0,2).reshape(-1,20)
27.5 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

在速度上有明显的差异。[30]只是一个view,所以这解释了为什么它这么快。我猜[28]的速度要慢得多,因为它涉及到元素的完全反转,其中[29]复制(40,50)块。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48534242

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档