lambdify包含UndefinedFunction的衍生物的sympy表达式

时间:2015-04-28 13:15:31

标签: python python-3.x sympy

我有几个未定义函数的表达式,其中一些包含该函数的相应(未定义)派生词。函数及其衍生物都只作为数值数据存在。我想用我的表达式创建函数,然后用相应的数值数据调用该函数来数值计算表达式。不幸的是我遇到了lambdify的问题。

考虑以下简化示例:

import sympy
import numpy

# define a parameter and an unknown function on said parameter
t = sympy.Symbol('t')
s = sympy.Function('s')(t)

# a "normal" expression
a = t*s**2
print(a)
#OUT: t*s(t)**2

# an expression which contains a derivative
b = a.diff(t)
print(b)
#OUT: 2*t*s(t)*Derivative(s(t), t) + s(t)**2

# generate an arbitrary numerical input
# for demo purposes lets assume that s(t):=sin(t)
t0 = 0
s0 = numpy.sin(t0)
sd0 = numpy.cos(t0)

# labdify a
fa = sympy.lambdify([t, s], a)
va = fa(t0, s0)
print (va)
#OUT: 0

# try to lambdify b
fb = sympy.lambdify([t, s, s.diff(t)], b)  # this fails with syntax error
vb = fb(t0, s0, sd0)
print (vb)

错误讯息:

  File "<string>", line 1
    lambda _Dummy_142,_Dummy_143,Derivative(s(t), t): (2*_Dummy_142*_Dummy_143*Derivative(_Dummy_143, _Dummy_142) + _Dummy_143**2)
                                           ^
SyntaxError: invalid syntax

显然Derivative对象未正确解析,我该如何解决?

作为lambdify的替代品,我也愿意使用基于aano或cython的解决方案,但我遇到了与相应打印机类似的问题。

感谢任何帮助。

3 个答案:

答案 0 :(得分:3)

据我所知,问题源于lambdify函数中的错误/不幸的虚假化过程。我已经编写了自己的dummification函数,我将其应用于参数以及表达式,然后再将它们传递给lambdifying。

def dummify_undefined_functions(expr):
    mapping = {}    

    # replace all Derivative terms
    for der in expr.atoms(sympy.Derivative):
        f_name = der.expr.func.__name__
        var_names = [var.name for var in der.variables]
        name = "d%s_d%s" % (f_name, 'd'.join(var_names))
        mapping[der] = sympy.Symbol(name)

    # replace undefined functions
    from sympy.core.function import AppliedUndef
    for f in expr.atoms(AppliedUndef):
        f_name = f.func.__name__
        mapping[f] = sympy.Symbol(f_name)

    return expr.subs(mapping)

像这样使用:

params = [dummify_undefined_functions(x) for x in [t, s, s.diff(t)]]
expr = dummify_undefined_functions(b)
fb = sympy.lambdify(params, expr)

显然这有些脆弱:

  • 无法防止名字冲突
  • 也许不是最好的名称方案:Derivative(f(x,y), x, y)
  • 的df_dxdy
  • 假设所有衍生品都具有以下形式: Derivative(s(t), t, ...) s(t)UndefinedFunctiontSymbol {{1}}。我不知道如果Derivative的任何参数是一个更复杂的表达式会发生什么。我有点想/希望(自动)简化过程将任何更复杂的衍生物减少为由“基本”衍生物组成的表达式。但我当然不会反对它。
  • 很大程度上未经测试(除了我的特定用例)

除此之外,它的效果非常好。

答案 1 :(得分:0)

讨论了类似的问题at here

您只需要定义自己的函数并将其派生词定义为另一个函数:

def f_impl(x):
    return x**2

def df_impl(x):
    return 2*x

class df(sy.Function):
    nargs = 1
    is_real = True
    _imp_ = staticmethod(df_impl)

class f(sy.Function):
    nargs = 1
    is_real = True
    _imp_ = staticmethod(f_impl)

    def fdiff(self, argindex=1):
        return df(self.args[0])

t = sy.Symbol('t')
print f(t).diff().subs({t:0.1})

expr = f(t) + f(t).diff()
expr_real = sy.lambdify(t, expr)
print expr_real(0.1)

答案 2 :(得分:0)

首先,不是使用UndefinedFunction,您可以继续使用并使用Implemented_function函数将s(t)的数字实现与符号函数联系起来。

然后,如果您受限于定义函数的离散数值数据,该函数的导数出现在麻烦的表达式中,那么很多时候,导数的数值计算可能来自有限差分。作为替代,sympy可以自动用有限差分替换导数项,并将结果表达式转换为lambda。例如:

import sympy
import numpy
from sympy.utilities.lambdify import lambdify, implemented_function
from sympy import Function

# define a parameter and an unknown function on said parameter
t = sympy.Symbol('t')
s = implemented_function(Function('s'), numpy.cos)

print('A plain ol\' expression')
a = t*s(t)**2
print(a)

print('Derivative of above:')
b = a.diff(t)
print(b)

# try to lambdify b by first replacing with finite differences
dx = 0.1
bapprox = b.replace(lambda arg: arg.is_Derivative,
                    lambda arg: arg.as_finite_difference(points=dx))
print('Approximation of derivatives:')
print(bapprox)
fb = sympy.lambdify([t], bapprox)

t0 = 0.0
vb = fb(t0)
print(vb)