假设我有一个大小为(N*N, N*N)的(稀疏)矩阵。我想从这个矩阵中选择元素,其中grid (一个(n,m)数组,其中n*m=N)的外积是True (它是一个布尔型2D数组,并且是na=grid.sum())。这可以按如下方式完成
result = M[np.outer( grid.flatten(),grid.flatten() )].reshape (( N, N ) )result是一个(na,na)稀疏数组(和na < N)。前一行是我想要实现的:从grid的乘积中获取M中为真的元素,并从数组中剔除不为真的元素。
随着n和m (以及N)的增长,而M和result是稀疏矩阵,我无法在内存或速度方面有效地做到这一点。我试过的最接近的是:
result = sp.lil_matrix ( (1, N*N), dtype=np.float32 )
# Calculate outer product
A = np.einsum("i,j", grid.flatten(), grid.flatten())
cntr = 0
it = np.nditer ( A, flags=['multi_index'] )
while not it.finished:
if it[0]:
result[0,cntr] = M[it.multi_index[0], it.multi_index[1]]
cntr += 1
# reshape result to be a N*N sparse matrix最后一次重塑可以由this approach完成,但我还没有完成,因为while循环将永远花费时间。
我也尝试过选择A的非零元素,并循环遍历,但这会耗尽所有内存:
A=np.einsum("i,j", grid.flatten(), grid.flatten())
nzero = A.nonzero() # This eats lots of memory
cntr = 0
for (i,j) in zip (*nzero):
temp_mat[0,cntr] = M[i,j]
cnt += 1上面例子中的“n”和“m”大约是300。
发布于 2016-07-23 13:45:13
我不知道是打字错误还是代码错误,但是您的示例缺少一个iternext
R=[]
it = np.nditer ( A, flags=['multi_index'] )
while not it.finished:
if it[0]:
R.append(M[it.multi_index])
it.iternext()我认为追加到列表中比R[ctnr]=...更简单、更快。如果R是一个规则数组,并且稀疏索引速度较慢(即使是最快的lil格式),这是很有竞争力的。
ndindex将nditer的这种用法包装为:
R=[]
for index in np.ndindex(A.shape):
if A[index]:
R.append(M[index])ndenumerate还可以工作:
R = []
for index,a in np.ndenumerate(A):
if a:
R.append(M[index])但我想知道,您是否真的想推进it的每一步,而不仅仅是True的情况。否则,将result重塑为(N,N)没有多大意义。但在这种情况下,你的问题不就是
M[:N, :N].multiply(A)或者,如果M是一个密集数组:
M[:N, :N]*A实际上,如果M和A都是稀疏的,那么该multiply的.data属性将与R列表相同。
In [76]: N=4
In [77]: M=np.arange(N*N*N*N).reshape(N*N,N*N)
In [80]: a=np.array([0,1,0,1])
In [81]: A=np.einsum('i,j',a,a)
In [82]: A
Out[82]:
array([[0, 0, 0, 0],
[0, 1, 0, 1],
[0, 0, 0, 0],
[0, 1, 0, 1]])
In [83]: M[:N, :N]*A
Out[83]:
array([[ 0, 0, 0, 0],
[ 0, 17, 0, 19],
[ 0, 0, 0, 0],
[ 0, 49, 0, 51]])
In [84]: c=sparse.csr_matrix(M)[:N,:N].multiply(sparse.csr_matrix(A))
In [85]: c.data
Out[85]: array([17, 19, 49, 51], dtype=int32)
In [89]: [M[index] for index, a in np.ndenumerate(A) if a]
Out[89]: [17, 19, 49, 51]https://stackoverflow.com/questions/38535861
复制相似问题