一个numpy einsum语句能复制gemm功能吗?标量和矩阵乘法看起来很简单,但我还没有找到如何让"+“工作的方法。如果它更简单,D=α*A*B+ beta *C是可以接受的(实际上更可取)
alpha = 2
beta = 3
A = np.arange(9).reshape(3, 3)
B = A + 1
C = B + 1
left_part = alpha*np.dot(A, B)
print(left_part)
left_part = np.einsum(',ij,jk->ik', alpha, A, B)
print(left_part)发布于 2016-11-03 19:16:23
这里似乎有些混乱:np.einsum处理可以以以下形式转换的操作:广播-乘-减少。按元素进行的求和不是其范围的一部分。
您之所以需要这样的乘法操作,是因为“天真地”写出这些操作可能会很快超过内存或计算资源。例如,考虑矩阵乘法:
import numpy as np
x, y = np.ones((2, 2000, 2000))
# explicit loop - ridiculously slow
a = sum(x[:,j,np.newaxis] * y[j,:] for j in range(2000))
# explicit broadcast-multiply-reduce: throws MemoryError
a = (x[:,:,np.newaxis] * y[:,np.newaxis,:]).sum(1)
# einsum or dot: fast and memory-saving
a = np.einsum('ij,jk->ik', x, y)然而,爱因斯坦公约考虑了附加因素,因此您可以简单地将类似BLAS的问题写成:
d = np.einsum(',ij,jk->ik', alpha, a, b) + np.einsum(',ik', beta, c)以最小的内存开销(如果您真的关心内存,可以将其中的大部分重写为就地操作)和恒定的运行时开销(两个python调用的成本)。
因此,在性能方面,对我来说,这似乎是一种过早优化的情况:您是否已经验证过,将类似GEMM的操作拆分为两个单独的numpy调用是代码中的瓶颈?如果确实如此,我建议如下(按增加参与程度排列):
scipy.linalg.blas.dgemm。如果您的性能显著提高,我会感到惊讶,因为dgemm通常只是构建块本身。https://stackoverflow.com/questions/39976383
复制相似问题