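I know the problem of partitioning an integer is an old one, and there are plenty of questions and answers about it here, but after extensive searching I still haven't found what I'm looking for. To be fair, my solution isn't too bad, but I'd like to know whether there is a faster/better way to do the following:
I need to partition an integer into a fixed-length partition that may contain the value 0, where each "position" in the partition is bounded by its own maximum value. For example: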
>>> list(partition(number = 5, max_vals = (1,0,3,4)))
[(1, 0, 3, 1),
(1, 0, 2, 2),
(1, 0, 0, 4),
(1, 0, 1, 3),
(0, 0, 1, 4),
(0, 0, 2, 3),
(0, 0, 3, 2)]

My solution is as follows:
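from collections import Counter
from itertools import combinations

def partition(number: int, max_vals: tuple):
    # Create one "token" per unit of capacity: slot i appears max_vals[i] times,
    # pick `number` tokens in every possible way, and let set() drop duplicate picks.
    S = set(combinations((k for i, val in enumerate(max_vals) for k in [i] * val), number))
    for s in S:
        # Count how many tokens were picked from each slot.
        c = Counter(s)
        yield tuple([c[n] for n in range(len(max_vals))])

Essentially, I first create "tokens" for each slot, then take combinations of the right number of them, and finally count how many tokens ended up in each slot.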
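I don't particularly like instantiating a Counter for every partition, but what bothers me most is that combinations generates far more tuples than needed, and I then throw away all the duplicates with set(), which seems very inefficient. Is there a better way?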
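To put a number on that waste, here is a small sketch (assuming Python 3.8+ for `math.comb`) that counts how many raw tuples `combinations()` produces for the example input and how many distinct partitions survive `set()`. The helper name `count_token_combinations` is hypothetical and only meant for illustration.

```python
from itertools import combinations
from math import comb


def count_token_combinations(number: int, max_vals: tuple):
    """Hypothetical helper: (raw combinations generated, distinct results kept)."""
    # Same token scheme as the question: slot i contributes max_vals[i] copies of i.
    tokens = [i for i, val in enumerate(max_vals) for _ in range(val)]
    raw = list(combinations(tokens, number))
    # combinations() always yields C(len(tokens), number) tuples.
    assert len(raw) == comb(len(tokens), number)
    return len(raw), len(set(raw))


raw, distinct = count_token_combinations(5, (1, 0, 3, 4))
print(raw, distinct)  # 56 7
```

For the example, 8 tokens give C(8, 5) = 56 tuples, of which only 7 are distinct, so most of the generated work is discarded; the ratio grows quickly as the capacities increase.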
Posted on 2021-01-25 18:47:43
Even though there must be a better algorithm, a relatively simpler and faster solution using itertools.product is:
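>>> from itertools import product
>>> def partition_2(number: int, max_vals: tuple):
...     # Each slot ranges over 0..min(number, cap); keep only tuples summing to number.
...     return (comb for comb in
...             product(*(range(min(number, i) + 1) for i in max_vals))
...             if sum(comb) == number)
...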
>>> list(partition_2(number = 5, max_vals = (1,0,3,4)))
[(0, 0, 1, 4),
(0, 0, 2, 3),
(0, 0, 3, 2),
(1, 0, 0, 4),
(1, 0, 1, 3),
(1, 0, 2, 2),
(1, 0, 3, 1)]

Performance:
>>> %timeit list(partition(number = 15, max_vals = (1,0,3,4)*3))
155 ms ± 681 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit list(partition_2(number = 15, max_vals = (1,0,3,4)*3))
14.7 ms ± 763 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
################################################################################
>>> %timeit list(partition(number = 5, max_vals = (10,20,30,10,10)))
1.17 s ± 26.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit list(partition_2(number = 5, max_vals = (10,20,30,10,10)))
1.21 ms ± 28.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
#################################################################################
>>> %timeit list(partition_2(number = 35, max_vals = (8,9,10,11,12)))
23.2 ms ± 697 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit list(partition(number = 35, max_vals = (8,9,10,11,12)))
# Will update when/if it finishes :)

Posted on 2021-01-25 23:15:16
A recursive function is often an elegant way to handle this kind of problem:
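def partition(N, slots):
    # Base case: the last slot takes whatever remains, if it fits.
    if len(slots) == 1:
        if slots[0] >= N: yield [N]
        return
    # Try every feasible amount s for the first slot, then recurse on the rest.
    for s in range(min(N, slots[0]) + 1):
        yield from ([s] + p for p in partition(N - s, slots[1:]))

for part in partition(5, [1, 0, 3, 4]): print(part)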
[0, 0, 1, 4]
[0, 0, 2, 3]
[0, 0, 3, 2]
[1, 0, 0, 4]
[1, 0, 1, 3]
[1, 0, 2, 2]
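[1, 0, 3, 1]

This can be optimized further by tracking the remaining space at each recursion level and short-circuiting the traversal when the remaining slots are too small to hold the number: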
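def partition(N, slots, space=None):
    # space = total capacity of the slots that are still unfilled
    if space is None: space = sum(slots)
    # Short-circuit: the remaining slots cannot hold N, so no solution exists here.
    if N > space: return
    if len(slots) == 1:
        if slots[0] >= N: yield [N]
        return
    for s in range(min(N, slots[0]) + 1):
        yield from ([s] + p for p in partition(N - s, slots[1:], space - slots[0]))

This optimization improves performance when the number of solutions is small compared with the full product of all slot ranges. When most slot combinations are valid, it is slower than iterating over the product.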
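One way to check that the space test actually prunes work is to count how many recursive generators each version runs. In this sketch, `partition_plain`, `partition_pruned` and the `calls` counter are hypothetical instrumentation added for illustration; apart from the counter, the two functions follow the two recursive versions in this answer.

```python
from collections import Counter

calls = Counter()  # hypothetical instrumentation: generator bodies executed per variant


def partition_plain(N, slots):
    calls["plain"] += 1
    if len(slots) == 1:
        if slots[0] >= N:
            yield [N]
        return
    for s in range(min(N, slots[0]) + 1):
        yield from ([s] + p for p in partition_plain(N - s, slots[1:]))


def partition_pruned(N, slots, space=None):
    calls["pruned"] += 1
    if space is None:
        space = sum(slots)
    if N > space:  # the remaining slots cannot hold N: abandon this branch
        return
    if len(slots) == 1:
        if slots[0] >= N:
            yield [N]
        return
    for s in range(min(N, slots[0]) + 1):
        rest = slots[1:]
        yield from ([s] + p for p in partition_pruned(N - s, rest, space - slots[0]))


slots = (8, 9, 10, 11, 12)
plain = list(partition_plain(45, slots))
pruned = list(partition_pruned(45, slots))
assert plain == pruned  # pruning only skips branches that could never yield
print(len(plain), calls["plain"], calls["pruned"])
```

With N = 45 and a total capacity of 50, most early choices leave too little room, so the pruned version abandons those branches immediately while the plain version explores them down to the last slot. Both return the same 126 partitions in the same order.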
from timeit import timeit
t = timeit(lambda:list(partition(45,(8,9,10,11,12))),number=1)
print(t) # 0.000679596
t = timeit(lambda:list(partition_2(45,(8,9,10,11,12))),number=1)
print(t) # 0.027492302 (Sayandip's)
t = timeit(lambda:list(partition(15,(1,0,3,4)*3)),number=1)
print(t) # 0.024383259
t = timeit(lambda:list(partition_2(15,(1,0,3,4)*3)),number=1)
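print(t) # 0.018362536

To get systematically better performance from the recursive approach, we need to limit the depth of the recursion. This can be done by approaching the problem differently. If we split the slots into two groups and first decide how the number is distributed between the two combined slots (left and right), we can then apply the partition to each side and combine the results. The recursion then only goes about log2(number of slots) levels deep, and it merges large chunks together instead of adding one value at a time: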
from itertools import product

def partition(N, slots, space=None):
    if space is not None and N > space: return
    if len(slots) == 1:
        if slots[0] >= N: yield [N]
        return
    if len(slots) == 2:
        # Enumerate the left value directly; the right slot gets N-left, which must fit.
        for left in range(max(0, N - slots[1]), min(N, slots[0]) + 1):
            yield [left, N - left]
        return
    # Split the slots in half and treat each half as one combined slot.
    leftSlots = slots[:len(slots)//2]
    rightSlots = slots[len(slots)//2:]
    leftSpace, rightSpace = sum(leftSlots), sum(rightSlots)
    # Distribute N between the halves, then partition each half independently.
    for leftN, rightN in partition(N, [leftSpace, rightSpace], leftSpace + rightSpace):
        partLeft = partition(leftN, leftSlots, leftSpace)
        partRight = partition(rightN, rightSlots, rightSpace)
        for leftSide, rightSide in product(partLeft, partRight):
            yield leftSide + rightSide

The performance improvement is consistent across all cases:
t = timeit(lambda:list(partition(45,(8,9,10,11,12))),number=1)
print(t) # 0.00017742
t = timeit(lambda:list(partition_2(45,(8,9,10,11,12))),number=1)
print(t) # 0.02895038
t = timeit(lambda:list(partition(15,(1,0,3,4)*3)),number=1)
print(t) # 0.00338676
t = timeit(lambda:list(partition_2(15,(1,0,3,4)*3)),number=1)
print(t) # 0.02025453

https://stackoverflow.com/questions/65889983