如何制作任意长度的numpy.piecewise函数? (有lambda问题)

时间:2014-05-30 21:00:08

标签: python numpy lambda piecewise

我试图将分段拟合绘制到我的数据中,但我需要使用任意数量的线段进行绘制。有时有三个部分;有时有两个。我将拟合系数存储在有效状态中,并将这些系数存储在btable中的分段上。

以下是我的界限的示例值:

btable = [[0.00499999989, 0.0244274978], [0.0244275965, 0.0599999987]]

以下是我的系数的示例值:

actable = [[0.0108687987, -0.673182865, 14.6420775], [0.00410866373, -0.0588355861, 1.07750032]]

这是我的代码:

rfig = plt.figure()
<>various other plot specifications<>
x = np.arange(0.005, 0.06, 0.0001)
y = np.piecewise(x, [(x >= btable[i][0]) & (x <= btable[i][1]) for i in range(len(btable))], [lambda x=x: np.log10(actable[j][0] + actable[j][2] * x + actable[j][2] * x**2) for j in list(range(len(actable)))])
plt.plot(x, y)

问题是lambda将自己设置为列表的最后一个实例,因此它使用所有段的最后一个段的系数。我不知道如何在不使用lambda的情况下进行分段函数。

目前,我这样做是在作弊:

if len(btable) == 2:
    y = np.piecewise(x, [(x >= btable[i][0]) & (x <= btable[i][1]) for i in range(len(btable))], [lambda x: np.log10(actable[0][0] + actable[0][1] * x + actable[0][2] * x**2), lambda x: np.log10(actable[1][0] + actable[1][1] * x + actable[1][2] * x**2)])
else if len(btable) == 3:
    y = np.piecewise(x, [(x >= btable[i][0]) & (x <= btable[i][1]) for i in range(len(btable))], [lambda x: np.log10(actable[0][0] + actable[0][1] * x + actable[0][2] * x**2), lambda x: np.log10(actable[1][0] + actable[1][1] * x + actable[1][2] * x**2), lambda x: np.log10(actable[2][0] + actable[2][1] * x + actable[2][2] * x**2)])
else
    print('Oh no!  You have fewer than 2 or more than 3 segments!')

但是这让我觉得内心很蠢。我知道必须有更好的解决方案。有人可以帮忙吗?

1 个答案:

答案 0 :(得分:0)

这个问题很常见,Python官方文档中有一篇文章Why do lambdas defined in a loop with different values all return the same result?,其中包含一个建议的解决方案:创建一个由循环变量初始化的局部变量,以捕获函数中后者的变化值。

也就是说,在y的定义中,它足以替换

[lambda x=x: np.log10(actable[j][0] + actable[j][1] * x + actable[j][2] * x**2) for j in range(len(actable))]

通过

[lambda x=x, k=j: np.log10(actable[k][0] + actable[k][1] * x + actable[k][2] * x**2) for j in range(len(actable))]

顺便说一下,可以使用单侧不等式来指定numpy.piecewise的范围:求值为True的条件的 last 将触发相应的函数。 (这有点违反直觉;使用第一个真实条件会更自然,就像SymPy那样)。如果断点按递增顺序排列,则应使用“x&gt; =”不等式:

breaks = np.arange(0, 10)       # breakpoints
coeff = np.arange(0, 20, 2)     # coefficients to use
x = np.arange(0, 10, 0.1)
y = np.piecewise(x, [x >= b for b in breaks], [lambda x=x, a=c: a*x for c in coeff])

这里每个系数将用于开始的间隔和相应的断点;例如,系数c = 0用于范围0<=x<1,系数c = 2,范围为1<=x<2,依此类推。

相关问题