使用sympy消除常见的子表达式

时间:2017-04-16 22:45:41

标签: python sympy codegen

我正在尝试使用sympy来优化C中数学表达式的数值计算。一方面我知道sympy可以生成C代码来评估一个表达式如下:

from mpmath import *
from sympy.utilities.codegen import codegen
from sympy  import *

x,y,z = symbols('x y z')
[(c_name, c_code), (h_name, c_header)] = codegen([('x', sin(x))], 'C')

然后您可以将c_code打印到目标文件。另一方面,我知道cse可用于简化表达式,如下所示:

from mpmath import *
from sympy.utilities.codegen import codegen
from sympy  import *

x,y,z, B1, B2, B3, B4 = symbols('x y z B1 B2 B3 B4 ')
cse([3.0*B2 + 8.0*B3*x**2 + 3.0*B3*x*y + 4.0*B3*x*z + B3*y**2 + B3*z**2 + B4*x**4 + B4*x**3*y + B4*x**3*z + B4*x**2*y**2 + B4*x**2*y*z + B4*x**2*z**2, 7.0*B3*x*y + 2*B3*x*z + B3*(x**2 + y**2) + B4*x**3*y + B4*x**2*y**2 + B4*x**2*y*z + B4*x*y**3 + B4*x*y**2*z + B4*x*y*z**2, B3*x*y + 8.0*B3*x*z + B3*(x**2 + z**2) + B4*x**3*z + B4*x**2*y*z + B4*x**2*z**2 + B4*x*y**2*z + B4*x*y*z**2 + B4*x*z**3, 3.0*B2 + B3*x**2 + 3.0*B3*x*y + B3*x*z + 8.0*B3*y**2 + 3.0*B3*y*z + B3*z**2 + B4*x**2*y**2 + B4*x*y**3 + B4*x*y**2*z + B4*y**4 + B4*y**3*z + B4*y**2*z**2, B3*x*y + 2*B3*x*z + 6.0*B3*y*z + B3*(y**2 + z**2) + B4*x**2*y*z + B4*x*y**2*z + B4*x*y*z**2 + B4*y**3*z + B4*y**2*z**2 + B4*y*z**3, 3.0*B2 + B3*x**2 + B3*x*y + 3.0*B3*x*z + B3*y**2 + 3.0*B3*y*z + 8.0*B3*z**2 + B4*x**2*z**2 + B4*x*y*z**2 + B4*x*z**3 + B4*y**2*z**2 + B4*y*z**3 + B4*z**4])

获得输出:

([(x0, z**2),
  (x1, B3*x0),
  (x2, B3*x),
  (x3, x2*y),
  (x4, 3.0*x3),
  (x5, 3.0*B2),
  (x6, y**2),
  (x7, B3*x6),
  (x8, x2*z),
  (x9, x**2),
  (x10, B3*x9),
  (x11, B4*x**3),
  (x12, x11*y),
  (x13, x11*z),
  (x14, B4*y),
  (x15, x14*x9*z),
  (x16, B4*x9),
  (x17, x16*x6),
  (x18, x0*x16),
  (x19, 2*x8),
  (x20, y**3),
  (x21, B4*x),
  (x22, x20*x21),
  (x23, x0*x21*y),
  (x24, x21*x6*z),
  (x25, z**3),
  (x26, x21*x25),
  (x27, B3*y*z),
  (x28, x10 + 3.0*x27),
  (x29, B4*x20*z),
  (x30, B4*x0*x6),
  (x31, x14*x25)],
 [B4*x**4 + x1 + 8.0*x10 + x12 + x13 + x15 + x17 + x18 + x4 + x5 + x7 + 4.0*x8,
  B3*(x6 + x9) + x12 + x15 + x17 + x19 + x22 + x23 + x24 + 7.0*x3,
  B3*(x0 + x9) + x13 + x15 + x18 + x23 + x24 + x26 + x3 + 8.0*x8,
  B4*y**4 + x1 + x17 + x22 + x24 + x28 + x29 + x30 + x4 + x5 + 8.0*x7 + x8,
  B3*(x0 + x6) + x15 + x19 + x23 + x24 + 6.0*x27 + x29 + x3 + x30 + x31,
  B4*z**4 + 8.0*x1 + x18 + x23 + x26 + x28 + x3 + x30 + x31 + x5 + x7 + 3.0*x8])

我的问题是如何正确转换C代码中的前一个结果?有时可以用来转换字符串中的简化表达式并对这些字符串进行操作,它是如何完成的?目标是在CSE之后自动生成代码的过程,以便将其应用于许多表达式。

编辑:

基于下面的答案,感谢Wrzlprmft,生成相应C代码片段的代码是:

from sympy.printing import ccode
from sympy import symbols, cse, numbered_symbols

x,y,z, B1, B2, B3, B4 = symbols('x y z B1 B2 B3 B4 ')
results = [3.0*B2 + 8.0*B3*x**2 + 3.0*B3*x*y + 4.0*B3*x*z + B3*y**2 + B3*z**2 + B4*x**4 + B4*x**3*y + B4*x**3*z + B4*x**2*y**2 + B4*x**2*y*z + B4*x**2*z**2, 7.0*B3*x*y + 2*B3*x*z + B3*(x**2 + y**2) + B4*x**3*y + B4*x**2*y**2 + B4*x**2*y*z + B4*x*y**3 + B4*x*y**2*z + B4*x*y*z**2, B3*x*y + 8.0*B3*x*z + B3*(x**2 + z**2) + B4*x**3*z + B4*x**2*y*z + B4*x**2*z**2 + B4*x*y**2*z + B4*x*y*z**2 + B4*x*z**3, 3.0*B2 + B3*x**2 + 3.0*B3*x*y + B3*x*z + 8.0*B3*y**2 + 3.0*B3*y*z + B3*z**2 + B4*x**2*y**2 + B4*x*y**3 + B4*x*y**2*z + B4*y**4 + B4*y**3*z + B4*y**2*z**2, B3*x*y + 2*B3*x*z + 6.0*B3*y*z + B3*(y**2 + z**2) + B4*x**2*y*z + B4*x*y**2*z + B4*x*y*z**2 + B4*y**3*z + B4*y**2*z**2 + B4*y*z**3, 3.0*B2 + B3*x**2 + B3*x*y + 3.0*B3*x*z + B3*y**2 + 3.0*B3*y*z + 8.0*B3*z**2 + B4*x**2*z**2 + B4*x*y*z**2 + B4*x*z**3 + B4*y**2*z**2 + B4*y*z**3 + B4*z**4]
CSE_results = cse(results,numbered_symbols("helper_"))

with open("snippet.c", "w") as output:
    for helper in CSE_results[0]:
        output.write("double ")
        output.write(ccode(helper[1],helper[0]))
        output.write("\n")

    for i,result in enumerate(CSE_results[1]):
        output.write(ccode(result,"result_%d"%i))
        output.write("\n")

1 个答案:

答案 0 :(得分:3)

最痛苦的方法是使用较低级ccode例程(然后根据需要创建周围的C样板):

from sympy.printing import ccode
from sympy import symbols, cse, numbered_symbols

x,y,z, B1, B2, B3, B4 = symbols('x y z B1 B2 B3 B4 ')
results = [3.0*B2 + 8.0*B3*x**2 + 3.0*B3*x*y + 4.0*B3*x*z + B3*y**2 + B3*z**2 + B4*x**4 + B4*x**3*y + B4*x**3*z + B4*x**2*y**2 + B4*x**2*y*z + B4*x**2*z**2, 7.0*B3*x*y + 2*B3*x*z + B3*(x**2 + y**2) + B4*x**3*y + B4*x**2*y**2 + B4*x**2*y*z + B4*x*y**3 + B4*x*y**2*z + B4*x*y*z**2, B3*x*y + 8.0*B3*x*z + B3*(x**2 + z**2) + B4*x**3*z + B4*x**2*y*z + B4*x**2*z**2 + B4*x*y**2*z + B4*x*y*z**2 + B4*x*z**3, 3.0*B2 + B3*x**2 + 3.0*B3*x*y + B3*x*z + 8.0*B3*y**2 + 3.0*B3*y*z + B3*z**2 + B4*x**2*y**2 + B4*x*y**3 + B4*x*y**2*z + B4*y**4 + B4*y**3*z + B4*y**2*z**2, B3*x*y + 2*B3*x*z + 6.0*B3*y*z + B3*(y**2 + z**2) + B4*x**2*y*z + B4*x*y**2*z + B4*x*y*z**2 + B4*y**3*z + B4*y**2*z**2 + B4*y*z**3, 3.0*B2 + B3*x**2 + B3*x*y + 3.0*B3*x*z + B3*y**2 + 3.0*B3*y*z + 8.0*B3*z**2 + B4*x**2*z**2 + B4*x*y*z**2 + B4*x*z**3 + B4*y**2*z**2 + B4*y*z**3 + B4*z**4]
CSE_results = cse(results,numbered_symbols("helper_"))

with open("snippet.c", "w") as output:
    for helper in CSE_results[0]:
        output.write("double ")
        output.write(ccode(helper[1],helper[0]))
        output.write("\n")

    for i,result in enumerate(CSE_results[1]):
        output.write(ccode(result,"result_%i"%i))
        output.write("\n")

这将生成一个如下所示的文件snippet.c

double helper_0 = pow(z, 2);
double helper_1 = B3*helper_0;
[…]
double helper_31 = helper_14*helper_25;
result_0 = B4*pow(x, 4) + helper_1 + 8.0*helper_10 + […];
[…]
result_5 = B4*pow(z, 4) + 8.0*helper_1 + helper_18 + […];
相关问题