首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用模式划分整数

用模式划分整数
EN

Stack Overflow用户
提问于 2021-01-25 17:58:31
回答 2查看 89关注 0票数 2

我知道划分一个整数的问题已经很久了,这里有很多关于它的问题和答案,但是经过广泛的搜索之后,我还没有找到我想要的东西。公平地说,我的解决方案并不太糟糕,但我想知道是否有一种更快/更好的方法来完成以下工作:

我需要将一个整数划分为一个固定长度的分区,该分区可能包含值0,并且分区中的每个“位置”都受一个最大可能值的限制。例如:

代码语言:javascript
复制
>>>list(partition(number = 5, max_vals = (1,0,3,4)))
[(1, 0, 3, 1),
 (1, 0, 2, 2),
 (1, 0, 0, 4),
 (1, 0, 1, 3),
 (0, 0, 1, 4),
 (0, 0, 2, 3),
 (0, 0, 3, 2)]

我的解决办法如下:

代码语言:javascript
复制
from collections import Counter
from itertools import combinations

def partition(number:int, max_vals:tuple):
    S = set(combinations((k for i,val in enumerate(max_vals) for k in [i]*val), number))
    for s in S:
        c = Counter(s)
        yield tuple([c[n] for n in range(len(max_vals))])

本质上,我首先为每个插槽创建“令牌”,然后组合正确的数目,最后计算每个插槽有多少个。

我并不特别喜欢为每个分区实例化一个Counter,但我最不喜欢的是,combinations生成的元组比所需的要多得多,然后我用set()放弃了所有的复制,这看起来效率很低。有更好的办法吗?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-01-25 18:47:43

即使必须有更好的算法,使用itertools.product的相对更简单和更快的解决方案是:

代码语言:javascript
复制
>>> from itertools import product
>>> def partition_2(number:int, max_vals:tuple):
        return (comb for comb in 
                product(*(range(min(number, i) + 1) for i in max_vals)) 
                if sum(comb)==number)

>>> list(partition_2(number = 5, max_vals = (1,0,3,4)))
[(0, 0, 1, 4),
 (0, 0, 2, 3),
 (0, 0, 3, 2),
 (1, 0, 0, 4),
 (1, 0, 1, 3),
 (1, 0, 2, 2),
 (1, 0, 3, 1)]

性能:

代码语言:javascript
复制
>>> %timeit list(partition(number = 15, max_vals = (1,0,3,4)*3))
155 ms ± 681 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

>>> %timeit list(partition_2(number = 15, max_vals = (1,0,3,4)*3))
14.7 ms ± 763 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
################################################################################
>>> %timeit list(partition(number = 5, max_vals = (10,20,30,10,10)))
1.17 s ± 26.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

>>> %timeit list(partition_2(number = 5, max_vals = (10,20,30,10,10)))
1.21 ms ± 28.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
#################################################################################
>>> %timeit list(partition_2(number = 35, max_vals = (8,9,10,11,12)))
23.2 ms ± 697 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

>>> %timeit list(partition(number = 35, max_vals = (8,9,10,11,12)))
# Will update when/if it finishes :)
票数 1
EN

Stack Overflow用户

发布于 2021-01-25 23:15:16

递归函数通常是处理此类问题的一种优雅方法:

代码语言:javascript
复制
def partition(N,slots):
    if len(slots)==1:
        if slots[0]>=N: yield [N]
        return
    for s in range(min(N,slots[0])+1):
        yield from ([s]+p for p in partition(N-s,slots[1:]))

                    
for part in partition(5,[1,0,3,4]): print(part)
[0, 0, 1, 4]
[0, 0, 2, 3]
[0, 0, 3, 2]
[1, 0, 0, 4]
[1, 0, 1, 3]
[1, 0, 2, 2]
[1, 0, 3, 1]    

这可以通过检查每个递归级别的剩余空间和短路遍历来进一步优化,当剩余的插槽不足以扩展该数目时:

代码语言:javascript
复制
def partition(N,slots,space=None):
    if space is None: space = sum(slots)
    if N>space: return
    if len(slots)==1:
        if slots[0]>=N: yield [N]
        return
    for s in range(min(N,slots[0])+1):
        yield from ([s]+p for p in partition(N-s,slots[1:],space-slots[0]))

在解决方案的数量少于所有插槽的全部产品的情况下,这种优化提高了性能。在大多数插槽组合工作的情况下,它比迭代要慢。

代码语言:javascript
复制
from timeit import timeit

t = timeit(lambda:list(partition(45,(8,9,10,11,12))),number=1)
print(t) # 0.000679596

t = timeit(lambda:list(partition_2(45,(8,9,10,11,12))),number=1)
print(t) # 0.027492302 (Sayandip's)


t = timeit(lambda:list(partition(15,(1,0,3,4)*3)),number=1)
print(t) # 0.024383259

t = timeit(lambda:list(partition_2(15,(1,0,3,4)*3)),number=1)
print(t) # 0.018362536

为了从递归方法中获得更好的系统性能,我们需要限制递归的深度。这可以通过以不同的方式处理问题来实现。如果我们将插槽分成两组,并确定两个组合槽(左和右)之间的分布,那么我们就可以在两边应用分区,并将结果组合起来。这只会恢复到Log2N的深度,并将大块合并在一起,而不是一次只添加一个值:

代码语言:javascript
复制
from itertools import product
def partition(N,slots,space=None):
    if space is not None and N>space: return
    if len(slots)==1:
        if slots[0]>=N: yield [N]
        return
    if len(slots)==2:
        for left in range(max(0,N-slots[1]),min(N,slots[0])+1):
            yield [left,N-left]
        return
    leftSlots  = slots[:len(slots)//2]
    rightSlots = slots[len(slots)//2:]
    leftSpace,rightSpace = sum(leftSlots),sum(rightSlots)
    for leftN,rightN in partition(N,[leftSpace,rightSpace],leftSpace+rightSpace):
        partLeft  = partition(leftN,  leftSlots,  leftSpace)
        partRight = partition(rightN, rightSlots, rightSpace)
        for leftSide,rightSide in product(partLeft,partRight):
            yield leftSide+rightSide

在所有情况下,性能改进都是系统的:

代码语言:javascript
复制
t = timeit(lambda:list(partition(45,(8,9,10,11,12))),number=1)
print(t) # 0.00017742

t = timeit(lambda:list(partition_2(45,(8,9,10,11,12))),number=1)
print(t) # 0.02895038


t = timeit(lambda:list(partition(15,(1,0,3,4)*3)),number=1)
print(t) # 0.00338676

t = timeit(lambda:list(partition_2(15,(1,0,3,4)*3)),number=1)
print(t) # 0.02025453
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65889983

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档