定义一个返回(插值)函数的函数

时间:2018-01-15 19:03:35

标签: python interpolation currying

我尝试编写一个返回插值函数的函数,该插值函数在插值范围之外进行线性推断。在查看关于currying的帖子后,我无法弄清楚为什么我没有工作。我有:

def interpolation(X_list,a_list): 
    A1=scipy.interpolate.UnivariateSpline(
        np.asarray(X_list),
        np.asarray(a_list),
        k=3,
        s=0,
        check_finite=True)

    m=(((a_list[-1])-(a_list[-2]))
       / ((X_list[-1])-(X_list[-2])))

    A1ext= m*X+a_list[-1]-m*X_list[-1]

    def a(X):
        if X_list[-1]>=X:
            return A1
        if X>X_list[-1]:
            return A1ext
    return a(X)

1 个答案:

答案 0 :(得分:0)

只需返回内部函数

def interpolation(X_list,a_list): 
    A1=scipy.interpolate.UnivariateSpline(
        np.asarray(X_list),
        np.asarray(a_list),
        k=3,
        s=0,
        check_finite=True)

    m=(((a_list[-1])-(a_list[-2]))
       / ((X_list[-1])-(X_list[-2])))

    A1ext= m*X+a_list[-1]-m*X_list[-1]

    def a(X):
        if X_list[-1]>=X:
            return A1
        if X>X_list[-1]:
            return A1ext
    return a

if __name__ == '__main__':
    a = interpolation(X_list,a_list)
    a(X)

为了缩短它,你可以用lambda匿名函数替换那个内部函数(a):

lambda X: A1 if X_list[-1] >= X else A1ext

现在它将是:

def interpolation(X_list,a_list): 
    A1=scipy.interpolate.UnivariateSpline(
        np.asarray(X_list),
        np.asarray(a_list),
        k=3,
        s=0,
        check_finite=True)

    m=(((a_list[-1])-(a_list[-2]))
       / ((X_list[-1])-(X_list[-2])))

    A1ext= m*X+a_list[-1]-m*X_list[-1]

    return lambda X: A1 if X_list[-1] >= X else A1ext