首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Numpy -如何删除尾随的N*8零

Numpy -如何删除尾随的N*8零
EN

Stack Overflow用户
提问于 2018-10-18 18:41:44
回答 3查看 152关注 0票数 0

我有一维数组,我需要删除所有8个零的尾随块。

代码语言:javascript
复制
[0,1,1,0,1,0,0,0, 0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0]
->
[0,1,1,0,1,0,0,0]

a.shape[0] % 8 == 0总是这样,所以不用担心。

有更好的方法吗?

代码语言:javascript
复制
import numpy as np
P = 8
arr1 = np.random.randint(2,size=np.random.randint(5,10) * P)
arr2 = np.random.randint(1,size=np.random.randint(5,10) * P)
arr = np.concatenate((arr1, arr2))

indexes = []
arr = np.flip(arr).reshape(arr.shape[0] // P, P)

for i, f in enumerate(arr):
    if (f == 0).all():
        indexes.append(i)
    else:
        break

arr = np.delete(arr, indexes, axis=0)
arr = np.flip(arr.reshape(arr.shape[0] * P))
EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2018-10-18 19:21:00

通过使用视图和np.argmax获取最后一个非零元素,您可以在不分配更多空间的情况下这样做:

代码语言:javascript
复制
index = arr.size - np.argmax(arr[::-1])

舍入到最接近8的倍数是很容易的:

代码语言:javascript
复制
index = np.ceil(index / 8) * 8

现在把剩下的砍下来:

代码语言:javascript
复制
arr = arr[:index]

或者是一字一句:

代码语言:javascript
复制
arr = arr[:(arr.size - np.argmax(arr[::-1])) / 8) * 8]

这个版本在时间上是O(n),在空间上是O(1),因为它对所有东西(包括输出)都重用相同的缓冲区。

这有一个额外的优点,即使没有尾随零,它也能正确工作。但是,使用argmax确实依赖于所有的元素是相同的。如果不是这样的话,您将需要首先计算一个掩码,例如使用arr.astype(bool)

如果您想使用您的原始方法,也可以将其向量化,尽管会有更多的开销:

代码语言:javascript
复制
view = arr.reshape(-1, 8)
mask = view.any(axis = 1)
index = view.shape[0] - np.argmax(mask[::-1])
arr = arr[:index * 8]
票数 0
EN

Stack Overflow用户

发布于 2018-10-18 20:09:11

有一个numpy函数几乎可以实现您想要的np.trim_zeros。我们可以利用它:

代码语言:javascript
复制
import numpy as np

def trim_mod(a, m=8):
    t = np.trim_zeros(a, 'b')
    return a[:len(a)-(len(a)-len(t))//m*m]

def test(a, t, m=8):
    assert (len(a) - len(t)) % m == 0
    assert len(t) < m or np.any(t[-m:])
    assert not np.any(a[len(t):])

for _ in range(1000):
    a = (np.random.random(np.random.randint(10, 100000))<0.002).astype(int)
    m = np.random.randint(4, 20)
    t = trim_mod(a, m)
    test(a, t, m)

print("Looks correct")

指纹:

代码语言:javascript
复制
Looks correct

它似乎在尾随零点的数量中线性地缩放:

但是从绝对的角度来看(单位是毫秒),所以np.trim_zeros可能只是一个python循环。

图片代码:

代码语言:javascript
复制
from timeit import timeit

A = (np.random.random(1000000)<0.02).astype(int)
m = 8
T = []
for last in range(1, 1000, 9):
    A[-last:] = 0
    A[-last] = 1
    T.append(timeit(lambda: trim_mod(A, m), number=100)*10)

import pylab
pylab.plot(range(1, 1000, 9), T)
pylab.show()
票数 0
EN

Stack Overflow用户

发布于 2018-10-20 18:15:11

低层次办法:

代码语言:javascript
复制
import numba
@numba.njit
def trim8(a):
    n=a.size-1
    while n>=0 and a[n]==0 : n-=1
    c= (n//8+1)*8
    return a[:c]

一些测试:

代码语言:javascript
复制
In [194]: A[-1]=1  # best case

In [196]: %timeit trim_mod(A,8)
5.7 µs ± 323 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [197]: %timeit trim8(A)
714 ns ± 33.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [198]: %timeit A[:(A.size - np.argmax(A[::-1]) // 8) * 8]
4.83 ms ± 479 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [202]: A[:]=0 #worst case

In [203]: %timeit trim_mod(A,8)
2.5 s ± 49.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [204]: %timeit trim8(A)
1.14 ms ± 71.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [205]: %timeit A[:(A.size - np.argmax(A[::-1]) // 8) * 8]
5.5 ms ± 950 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

它有像trim_zeros这样的短路机制,但速度更快。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52880541

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档