如果我有numpy数组A和B,那么我可以用以下公式计算它们的矩阵乘积的迹:
tr = numpy.linalg.trace(A.dot(B))然而,当在跟踪中只使用对角元素时,矩阵乘法A.dot(B)不必要地计算矩阵乘积中的所有非对角线条目。相反,我可以这样做:
tr = 0.0
for i in range(n):
tr += A[i, :].dot(B[:, i])但是这在Python代码中执行循环,并且不像numpy.linalg.trace那样明显。
有没有更好的方法来计算numpy数组的矩阵乘积的迹?做这件事最快或最常用的方式是什么?
发布于 2013-09-18 00:39:25
您可以通过将中间存储减少到仅对角线元素来改进@Bill的解决方案:
from numpy.core.umath_tests import inner1d
m, n = 1000, 500
a = np.random.rand(m, n)
b = np.random.rand(n, m)
# They all should give the same result
print np.trace(a.dot(b))
print np.sum(a*b.T)
print np.sum(inner1d(a, b.T))
%timeit np.trace(a.dot(b))
10 loops, best of 3: 34.7 ms per loop
%timeit np.sum(a*b.T)
100 loops, best of 3: 4.85 ms per loop
%timeit np.sum(inner1d(a, b.T))
1000 loops, best of 3: 1.83 ms per loop另一种选择是使用np.einsum,并且根本没有显式的中间存储:
# Will print the same as the others:
print np.einsum('ij,ji->', a, b)在我的系统上,它的运行速度比使用inner1d稍慢,但它可能并不适用于所有系统,请参阅this question
%timeit np.einsum('ij,ji->', a, b)
100 loops, best of 3: 1.91 ms per loop发布于 2013-09-18 00:12:13
在wikipedia中,您可以使用hadamard乘积(逐元素乘法)计算轨迹:
# Tr(A.B)
tr = (A*B.T).sum()我认为这比使用numpy.trace(A.dot(B))需要更少的计算。
编辑:
运行了一些定时器。这种方式比使用numpy.trace快得多。
In [37]: timeit("np.trace(A.dot(B))", setup="""import numpy as np;
A, B = np.random.rand(1000,1000), np.random.rand(1000,1000)""", number=100)
Out[38]: 8.6434469223022461
In [39]: timeit("(A*B.T).sum()", setup="""import numpy as np;
A, B = np.random.rand(1000,1000), np.random.rand(1000,1000)""", number=100)
Out[40]: 0.5516049861907959发布于 2017-02-22 23:18:58
请注意,一个细微的变化是取vec变换矩阵的点积。在python中,矢量化是使用.flatten('F')完成的。在我的电脑上,它比求Hadamard乘积的和慢一点,所以这是一个比wflynny的更差的解决方案,但我认为它很有趣,因为在某些情况下,我认为它可能更直观。例如,就我个人而言,我发现对于矩阵正态分布,向量化的解对我来说更容易理解。
在我的系统上,速度比较:
import numpy as np
import time
N = 1000
np.random.seed(123)
A = np.random.randn(N, N)
B = np.random.randn(N, N)
tart = time.time()
for i in range(10):
C = np.trace(A.dot(B))
print(time.time() - start, C)
start = time.time()
for i in range(10):
C = A.flatten('F').dot(B.T.flatten('F'))
print(time.time() - start, C)
start = time.time()
for i in range(10):
C = (A.T * B).sum()
print(time.time() - start, C)
start = time.time()
for i in range(10):
C = (A * B.T).sum()
print(time.time() - start, C)结果:
6.246593236923218 -629.370798672
0.06539678573608398 -629.370798672
0.057890892028808594 -629.370798672
0.05709719657897949 -629.370798672https://stackoverflow.com/questions/18854425
复制相似问题