首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何有效地对术语进行分组以表示对λ的理解?

如何有效地对术语进行分组以表示对λ的理解?
EN

Stack Overflow用户
提问于 2019-08-30 19:39:04
回答 2查看 359关注 0票数 0

我有一个渐近多项式,它有一个非常多的项。我想抨击那个公式。然而,由于它有一个非常多的项,并且多项式是展开的,所以有更多的运算下降,而不是最优的。具体来说,通过将某些术语组合在一起,我们可以消除一些操作。考虑以下等式,例如:

代码语言:javascript
复制
x^2y^2 + x^2y + x^2 + 1

如果我对此进行抨击,那么,如果xy是长度为N的一维np.arrays,则会有4个元素方向的平方-ings,2个按元素方向的乘法,以及3个按元素方向的加法,从而产生大约9*N的运算。

OTOH,通过做一点代数,我们得出:

代码语言:javascript
复制
x^2(y^2 + y + 1) + 1

通过奇偶推理,这个公式只涉及6*N运算。如果我有一个更大更复杂的公式,差别可能会相当大。

在任何情况下,我都不需要找到使性能最大化的表示,但是很明显,对术语进行分组至少可以提高性能。

我如何做这种“术语分组”,以实现一个更有效的表示我的sympy公式时,兰巴达菲?

EN

回答 2

Stack Overflow用户

发布于 2019-08-30 21:30:56

您可以按照相同的符号对术语进行分组,并在其上使用horner

代码语言:javascript
复制
>>> d=defaultdict(list)
>>> for t in Add.make_args(eq):
...  d[tuple(ordered(t.free_symbols))].append(t)
...
>>> Add(*[horner(Add(*i)) for i in d.values()])
x**2*y*(y + 1) + x**2 + 1
票数 2
EN

Stack Overflow用户

发布于 2019-09-02 14:59:12

最后我使用了sympy.collect。如果方程没有太多的变量,那么就可以简单地强求所有的组合,并恢复到“收集”项中。

这是我想出的密码。可能还有很多改进的余地:

代码语言:javascript
复制
def collect_best(expr, measure=sympy.count_ops):
    # This method performs sympy.collect over all permutations of the free variables, and returns the best collection
    best = expr
    best_score = measure(expr)
    perms = itertools.permutations(expr.free_symbols)
    permlen = np.math.factorial(len(expr.free_symbols))
    print(permlen)
    for i, perm in enumerate(perms):
        if (permlen > 1000) and not (i%int(permlen/100)):
            print(i)
        collected = sympy.collect(expr, perm)
        if measure(collected) < best_score:
            best_score = measure(collected)
            best = collected
    return best

def product(args):
    arg = next(args)
    try:
        return arg*product(args)
    except:
        return arg

def rcollect_best(expr, measure=sympy.count_ops):
    # This method performs collect_best recursively on the collected terms
    best = collect_best(expr, measure)
    best_score = measure(best)
    if expr == best:
        return best
    if isinstance(best, sympy.Mul):
        return product(map(rcollect_best, best.args))
    if isinstance(best, sympy.Add):
        return sum(map(rcollect_best, best.args))

rcollect_best将此转换为(count_ops = 136):

代码语言:javascript
复制
4*a**3*d*e - 6*a**2*b*d*e - 6*a**2*c*d*e + 16*a**2*e**3 + 6*a**2*e*f**2 + 6*a**2*e*g**2 + 2*a*b**2*d*e + 8*a*b*c*d*e - 14*a*b*e**3 - 2*a*b*e*f**2 - 8*a*b*e*g**2 + 2*a*c**2*d*e - 14*a*c*e**3 - 8*a*c*e*f**2 - 2*a*c*e*g**2 - 2*b**2*c*d*e + 2*b**2*e**3 + 2*b**2*e*g**2 - 2*b*c**2*d*e + 8*b*c*e**3 + 2*b*c*e*f**2 + 2*b*c*e*g**2 + 2*c**2*e**3 + 2*c**2*e*f**2

其中(count_ops = 68):

代码语言:javascript
复制
2*e*(d*(2*a**3 - 3*a**2*b + a*b**2 + c**2*(a - b) + c*(-3*a**2 + 4*a*b - b**2)) + e**2*(8*a**2 - 7*a*b + b**2 + c**2 + c*(-7*a + 4*b)) + f**2*(3*a**2 - a*b + c**2 + c*(-4*a + b)) + g**2*(3*a**2 - 4*a*b + b**2 + c*(-a + b)))

这是7个变量中的一个5次多项式。运行时间大约是10到15分钟,并且会以指数级增长,所以我不推荐比这要求更高的东西。我确信有一些基本的改进可以解决超指数增长,但这已经解决了我的问题,所以我现在正在兑现。:)

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57732417

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档