对于当前的项目,我必须用相同的矩阵计算许多向量的内积(这是相当稀疏的)。向量与二维网格相关联,因此我将向量存储在一个三维数组中:
例如:
X是一个由(I,J,N)组成的数组。矩阵A为dim (N,N)。现在的任务是为A.dot(X[i,j])中的每个i,j计算I,J。
对于numpy数组,这很容易用
Y = X.dot(A.T) 现在,我想将A存储为稀疏矩阵,因为它是稀疏的,并且只包含非常有限的非零条目,这会导致许多不必要的乘法。不幸的是,上面的解决方案不能工作,因为numpy点不适用于稀疏矩阵。据我所知,对于枕骨稀疏,没有张力点式的手术。
有人知道用稀疏矩阵Y计算上述数组A的好方法吗?
发布于 2013-09-19 19:54:18
显而易见的方法是在向量上运行一个循环,并使用稀疏矩阵的.dot方法:
def naive_sps_x_dense_vecs(sps_mat, dense_vecs):
rows, cols = sps_mat.shape
I, J, _ = dense_vecs.shape
out = np.empty((I, J, rows))
for i in xrange(I):
for j in xrange(J):
out[i, j] = sps_mat.dot(dense_vecs[i, j])
return out但是,通过将3d数组重组为2d,并避免Python循环,您可以稍微加快速度:
def sps_x_dense_vecs(sps_mat, dense_vecs):
rows, cols = sps_mat.shape
vecs_shape = dense_vecs.shape
dense_vecs = dense_vecs.reshape(-1, cols)
out = sps_mat.dot(dense_vecs.T).T
return out.reshape(vecs.shape[:-1] + (rows,))问题是,我们需要将稀疏矩阵作为第一个参数,这样我们就可以调用它的.dot方法,这意味着返回被转置,这意味着在转置之后,最后一个整形将触发整个数组的副本。因此,对于相当大的I和J值,再加上不太大的N值,后者的速度将是前者的几倍,但对于其他参数组合,性能甚至可能相反:
n, i, j = 100, 500, 500
a = sps.rand(n, n, density=1/n, format='csc')
vecs = np.random.rand(i, j, n)
>>> np.allclose(naive_sps_x_dense_vecs(a, vecs), sps_x_dense_vecs(a, vecs))
True
n, i, j = 100, 500, 500
%timeit naive_sps_x_dense_vecs(a, vecs)
1 loops, best of 3: 3.85 s per loop
%timeit sps_x_dense_vecs(a, vecs)
1 loops, best of 3: 576 ms per
n, i, j = 1000, 200, 200
%timeit naive_sps_x_dense_vecs(a, vecs)
1 loops, best of 3: 791 ms per loop
%timeit sps_x_dense_vecs(a, vecs)
1 loops, best of 3: 1.3 s per loop发布于 2022-04-20 01:30:05
您可以使用jax来实现您想要的目标。让我们假设稀疏矩阵是csr_array格式的。首先将其转换为jax BCOO array
from scipy import sparse
from jax.experimental import sparse as jaxsparse
import jax.numpy as jnp
def convert_to_BCOO(x):
x = x.transpose() #get the transpose
x = x.tocoo()
x = jaxsparse.BCOO((x.data, jnp.column_stack((x.row, x.col))),
shape=x.shape)
x = L.sort_indices()然后,您可以使用jax.sparsify创建一个稀疏点产品,如下所示。
def dot(x, y):
return jnp.dot(x, y)
sp_dot = jaxsparse.sparsify(dot)
A_transpose = convert_to_BCOO(A)
Y = sp_dot(X,A_transpose)函数sp_dot现在遵循与numpy.dot完全相同的规则。
希望这能有所帮助!
https://stackoverflow.com/questions/18901938
复制相似问题