首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >具有重复索引的numpy数组的矢量化赋值(d[i,j,i,j] = s[i,j])

具有重复索引的numpy数组的矢量化赋值(d[i,j,i,j] = s[i,j])
EN

Stack Overflow用户
提问于 2017-05-09 02:45:11
回答 2查看 112关注 0票数 3

如何设置

代码语言:javascript
复制
d[i,j,i,j] = s[i,j]

使用"NumPy“而不使用for循环?

我尝试过以下几种方法:

代码语言:javascript
复制
l1=range(M)
l2=range(N)
d[l1,l2,l1,l2] = s[l1,l2]
EN

回答 2

Stack Overflow用户

发布于 2017-05-09 02:54:13

您可以使用integer array indexing (使用np.ix_创建广播索引):

代码语言:javascript
复制
d[np.ix_(l1,l2)*2] = s[np.ix_(l1,l2)]

第一次必须复制索引(您需要[i, j, i, j]而不仅仅是[i, j]),这就是我将np.ix_返回的tuple乘以2的原因。

例如:

代码语言:javascript
复制
>>> d = np.zeros((10, 10, 10, 10), dtype=int)
>>> s = np.arange(100).reshape(10, 10)
>>> l1 = range(3)
>>> l2 = range(5)
>>> d[np.ix_(l1,l2)*2] = s[np.ix_(l1,l2)]

并确保分配了正确的值:

代码语言:javascript
复制
>>> # Assert equality for the given condition
>>> for i in l1:
...     for j in l2:
...         assert d[i, j, i, j] == s[i, j]

>>> # Interactive tests
>>> d[0, 0, 0, 0], s[0, 0]
(0, 0)
>>> d[1, 2, 1, 2], s[1, 2]
(12, 12)
>>> d[2, 0, 2, 0], s[2, 0]
(20, 20)
>>> d[2, 4, 2, 4], s[2, 4]
(24, 24)
票数 1
EN

Stack Overflow用户

发布于 2017-05-09 03:26:17

如果你想一想,这就等同于创建一个形状为(m*n, m*n)2D数组,并将s中的值赋给对角线位置。为了让最终的输出成为4D,我们只需要在最后重塑一下。这基本上在下面实现-

代码语言:javascript
复制
m,n = s.shape
d = np.zeros((m*n,m*n),dtype=s.dtype)
d.ravel()[::m*n+1] = s.ravel()
d.shape = (m,n,m,n)

运行时测试

方法-

代码语言:javascript
复制
# @MSeifert's solution
def assign_vals_ix(s):    
    d = np.zeros((m, n, m, n), dtype=s.dtype)
    l1 = range(m)
    l2 = range(n)
    d[np.ix_(l1,l2)*2] = s[np.ix_(l1,l2)]
    return d

# Proposed in this post
def assign_vals(s):
    m,n = s.shape
    d = np.zeros((m*n,m*n),dtype=s.dtype)
    d.ravel()[::m*n+1] = s.ravel()
    return d.reshape(m,n,m,n)

# Using a strides based approach
def assign_vals_strides(a):
    m,n = a.shape
    p,q = a.strides

    d = np.zeros((m,n,m,n),dtype=a.dtype)
    out_strides = (q*(n*m*n+n),(m*n+1)*q)
    d_view = np.lib.stride_tricks.as_strided(d, (m,n), out_strides)
    d_view[:] = a
    return d

计时-

代码语言:javascript
复制
In [285]: m,n = 10,10
     ...: s = np.random.rand(m,n)
     ...: d = np.zeros((m,n,m,n))
     ...: 

In [286]: %timeit assign_vals_ix(s)
10000 loops, best of 3: 21.3 µs per loop

In [287]: %timeit assign_vals_strides(s)
100000 loops, best of 3: 9.37 µs per loop

In [288]: %timeit assign_vals(s)
100000 loops, best of 3: 4.13 µs per loop

In [289]: m,n = 20,20
     ...: s = np.random.rand(m,n)
     ...: d = np.zeros((m,n,m,n))


In [290]: %timeit assign_vals_ix(s)
10000 loops, best of 3: 60.2 µs per loop

In [291]: %timeit assign_vals_strides(s)
10000 loops, best of 3: 41.8 µs per loop

In [292]: %timeit assign_vals(s)
10000 loops, best of 3: 35.5 µs per loop
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/43855086

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档