我有一个渐近多项式,它有一个非常多的项。我想抨击那个公式。然而,由于它有一个非常多的项,并且多项式是展开的,所以有更多的运算下降,而不是最优的。具体来说,通过将某些术语组合在一起,我们可以消除一些操作。考虑以下等式,例如:
x^2y^2 + x^2y + x^2 + 1如果我对此进行抨击,那么,如果x和y是长度为N的一维np.arrays,则会有4个元素方向的平方-ings,2个按元素方向的乘法,以及3个按元素方向的加法,从而产生大约9*N的运算。
OTOH,通过做一点代数,我们得出:
x^2(y^2 + y + 1) + 1通过奇偶推理,这个公式只涉及6*N运算。如果我有一个更大更复杂的公式,差别可能会相当大。
在任何情况下,我都不需要找到使性能最大化的表示,但是很明显,对术语进行分组至少可以提高性能。
我如何做这种“术语分组”,以实现一个更有效的表示我的sympy公式时,兰巴达菲?
发布于 2019-08-30 21:30:56
您可以按照相同的符号对术语进行分组,并在其上使用horner:
>>> d=defaultdict(list)
>>> for t in Add.make_args(eq):
... d[tuple(ordered(t.free_symbols))].append(t)
...
>>> Add(*[horner(Add(*i)) for i in d.values()])
x**2*y*(y + 1) + x**2 + 1发布于 2019-09-02 14:59:12
最后我使用了sympy.collect。如果方程没有太多的变量,那么就可以简单地强求所有的组合,并恢复到“收集”项中。
这是我想出的密码。可能还有很多改进的余地:
def collect_best(expr, measure=sympy.count_ops):
# This method performs sympy.collect over all permutations of the free variables, and returns the best collection
best = expr
best_score = measure(expr)
perms = itertools.permutations(expr.free_symbols)
permlen = np.math.factorial(len(expr.free_symbols))
print(permlen)
for i, perm in enumerate(perms):
if (permlen > 1000) and not (i%int(permlen/100)):
print(i)
collected = sympy.collect(expr, perm)
if measure(collected) < best_score:
best_score = measure(collected)
best = collected
return best
def product(args):
arg = next(args)
try:
return arg*product(args)
except:
return arg
def rcollect_best(expr, measure=sympy.count_ops):
# This method performs collect_best recursively on the collected terms
best = collect_best(expr, measure)
best_score = measure(best)
if expr == best:
return best
if isinstance(best, sympy.Mul):
return product(map(rcollect_best, best.args))
if isinstance(best, sympy.Add):
return sum(map(rcollect_best, best.args))rcollect_best将此转换为(count_ops = 136):
4*a**3*d*e - 6*a**2*b*d*e - 6*a**2*c*d*e + 16*a**2*e**3 + 6*a**2*e*f**2 + 6*a**2*e*g**2 + 2*a*b**2*d*e + 8*a*b*c*d*e - 14*a*b*e**3 - 2*a*b*e*f**2 - 8*a*b*e*g**2 + 2*a*c**2*d*e - 14*a*c*e**3 - 8*a*c*e*f**2 - 2*a*c*e*g**2 - 2*b**2*c*d*e + 2*b**2*e**3 + 2*b**2*e*g**2 - 2*b*c**2*d*e + 8*b*c*e**3 + 2*b*c*e*f**2 + 2*b*c*e*g**2 + 2*c**2*e**3 + 2*c**2*e*f**2其中(count_ops = 68):
2*e*(d*(2*a**3 - 3*a**2*b + a*b**2 + c**2*(a - b) + c*(-3*a**2 + 4*a*b - b**2)) + e**2*(8*a**2 - 7*a*b + b**2 + c**2 + c*(-7*a + 4*b)) + f**2*(3*a**2 - a*b + c**2 + c*(-4*a + b)) + g**2*(3*a**2 - 4*a*b + b**2 + c*(-a + b)))这是7个变量中的一个5次多项式。运行时间大约是10到15分钟,并且会以指数级增长,所以我不推荐比这要求更高的东西。我确信有一些基本的改进可以解决超指数增长,但这已经解决了我的问题,所以我现在正在兑现。:)
https://stackoverflow.com/questions/57732417
复制相似问题