具有n个断点的分段线性拟合

时间:2017-09-14 12:18:22

标签: python numpy scipy curve-fitting piecewise

我使用了问题How to apply piecewise linear fit in Python?中的一些代码,用单个断点执行分段线性逼近。

代码如下:

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03])

def piecewise_linear(x, x0, y0, k1, k2):
    return np.piecewise(x, 
                       [x < x0], 
                       [lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0])

p , e = optimize.curve_fit(piecewise_linear, x, y)
xd = np.linspace(0, 15, 100)
plt.plot(x, y, "o")
plt.plot(xd, piecewise_linear(xd, *p))

我试图找出如何扩展它以处理n个断点。

我为piecewise_linear()方法尝试了以下代码来处理2个断点,但它不会以任何方式改变断点的值。

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03, 150, 152, 154, 156, 158])

def piecewise_linear(x, x0, x1, a1, b1, a2, b2, a3, b3):
    return np.piecewise(x,
                       [x < x0, np.logical_and(x >= x0, x < x1), x >= x1 ], 
                       [lambda x:a1*x + b1, lambda x:a2*x+b2, lambda x: a3*x + b3])

p , e = optimize.curve_fit(piecewise_linear, x, y)
xd = np.linspace(0, 20, 100)
plt.plot(x, y, "o")
plt.plot(xd, piecewise_linear(xd, *p))

非常感谢任何输入

1 个答案:

答案 0 :(得分:6)

NumPy有一个polyfit function,可以很容易地通过一组点找到最合适的线:

coefs = npoly.polyfit(xi, yi, 1)

所以真的唯一的困难就是找到断点。对于给定的一组 断点通过给定的数据找到最合适的线条是微不足道的。

因此,不要试图找到断点系数的位置 在线性部分的同时,足以使参数空间最小化 断点。

由于断点可以通过整数索引值指定到x数组中, 参数空间可以被认为是N维的整数网格上的点,其中 N是断点的数量。

optimize.curve_fit不是这个问题的最小化器的好选择 因为参数空间是整数值。如果您使用curve_fit, 算法会调整参数以确定在哪个方向上 移动。如果调整小于1个单位,则断点的x值不会 更改,因此错误不会更改,因此算法不会获得任何信息 关于移动参数的正确方向。因此curve_fit 当参数空间基本上是整数值时,它往往会失败。

一个更好但不是非常快的最小化器将是一个强力网格搜索。如果 断点的数量很小(x的参数空间 - 值是 小)这可能就足够了。如果断点的数量很大和/或 参数空间很大,那么也许设置一个多级粗/细 (蛮力)网格搜索。或者,也许有人会建议使用比蛮力更聪明的最小化器......

import numpy as np
import numpy.polynomial.polynomial as npoly
from scipy import optimize
import matplotlib.pyplot as plt
np.random.seed(2017)

def f(breakpoints, x, y, fcache):
    breakpoints = tuple(map(int, sorted(breakpoints)))
    if breakpoints not in fcache:
        total_error = 0
        for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y):
            total_error += ((f(xi) - yi)**2).sum()
        fcache[breakpoints] = total_error
    # print('{} --> {}'.format(breakpoints, fcache[breakpoints]))
    return fcache[breakpoints]

def find_best_piecewise_polynomial(breakpoints, x, y):
    breakpoints = tuple(map(int, sorted(breakpoints)))
    xs = np.split(x, breakpoints)
    ys = np.split(y, breakpoints)
    result = []
    for xi, yi in zip(xs, ys):
        if len(xi) < 2: continue
        coefs = npoly.polyfit(xi, yi, 1)
        f = npoly.Polynomial(coefs)
        result.append([f, xi, yi])
    return result

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 
              18, 19, 20], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 
              126.14, 140.03, 150, 152, 154, 156, 158])
# Add some noise to make it exciting :)
y += np.random.random(len(y))*10

num_breakpoints = 2
breakpoints = optimize.brute(
    f, [slice(1, len(x), 1)]*num_breakpoints, args=(x, y, {}), finish=None)

plt.scatter(x, y, c='blue', s=50)
for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y):
    x_interval = np.array([xi.min(), xi.max()])
    print('y = {:35s}, if x in [{}, {}]'.format(str(f), *x_interval))
    plt.plot(x_interval, f(x_interval), 'ro-')


plt.show()

打印

y = poly([ 4.58801083  2.94476604])    , if x in [1.0, 6.0]
y = poly([-70.36472935  14.37305793])  , if x in [7.0, 15.0]
y = poly([ 123.24565235    1.94982153]), if x in [16.0, 20.0]

和情节

enter image description here

相关问题