首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >numpy数组的累积argmax

numpy数组的累积argmax
EN

Stack Overflow用户
提问于 2016-11-18 08:11:44
回答 3查看 959关注 0票数 5

考虑数组a

代码语言:javascript
复制
np.random.seed([3,1415])
a = np.random.randint(0, 10, (10, 2))
a

array([[0, 2],
       [7, 3],
       [8, 7],
       [0, 6],
       [8, 6],
       [0, 2],
       [0, 4],
       [9, 7],
       [3, 2],
       [4, 3]])

什么是矢量化的方法来获得累积的argmax?

代码语言:javascript
复制
array([[0, 0],  <-- both start off as max position
       [1, 1],  <-- 7 > 0 so 1st col = 1, 3 > 2 2nd col = 1
       [2, 2],  <-- 8 > 7 1st col = 2, 7 > 3 2nd col = 2
       [2, 2],  <-- 0 < 8 1st col stays the same, 6 < 7 2nd col stays the same
       [2, 2],  
       [2, 2],
       [2, 2],
       [7, 2],  <-- 9 is new max of 2nd col, argmax is now 7
       [7, 2],
       [7, 2]])

这是一种非矢量化的方法。

注意,随着窗口的扩展,argmax应用于正在增长的窗口。

代码语言:javascript
复制
pd.DataFrame(a).expanding().apply(np.argmax).astype(int).values

array([[0, 0],
       [1, 1],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [7, 2],
       [7, 2],
       [7, 2]])
EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2016-11-18 11:28:54

下面是一个矢量化的纯NumPy解决方案,它执行得非常快:

代码语言:javascript
复制
def cumargmax(a):
    m = np.maximum.accumulate(a)
    x = np.repeat(np.arange(a.shape[0])[:, None], a.shape[1], axis=1)
    x[1:] *= m[:-1] < m[1:]
    np.maximum.accumulate(x, axis=0, out=x)
    return x

然后我们有:

代码语言:javascript
复制
>>> cumargmax(a)
array([[0, 0],
       [1, 1],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [7, 2],
       [7, 2],
       [7, 2]])

在具有数千到数百万个值的数组上进行的一些快速测试表明,这比在Python级别(隐式或显式)循环的速度要快10-50倍。

票数 7
EN

Stack Overflow用户

发布于 2016-11-18 08:36:35

我想不出一种方法可以轻松地对这两列进行矢量化;但是,如果列的数量相对于行数来说很小,那就不应该是一个问题,而for循环应该可以满足这个轴的要求:

代码语言:javascript
复制
import numpy as np
import numpy_indexed as npi
a = np.random.randint(0, 10, (10))
max = np.maximum.accumulate(a)
idx = npi.indices(a, max)
print(idx)
票数 1
EN

Stack Overflow用户

发布于 2016-11-18 09:03:02

我想要创建一个函数来计算1d数组的累积argmax,然后将其应用于所有列。这是代码:

代码语言:javascript
复制
import numpy as np

np.random.seed([3,1415])
a = np.random.randint(0, 10, (10, 2))

def cumargmax(v):
    uargmax = np.frompyfunc(lambda i, j: j if v[j] > v[i] else i, 2, 1)
    return uargmax.accumulate(np.arange(0, len(v)), 0, dtype=np.object).astype(v.dtype)

np.apply_along_axis(cumargmax, 0, a)

转换为np.object然后再转换回Numpy 1.9的原因是Numpy 1.9的一个解决方法,如generalized cumulative functions in NumPy/SciPy?中所提到的

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/40672186

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档