我使用了问题How to apply piecewise linear fit in Python?中的一些代码,用单个断点执行分段线性逼近。
代码如下:
from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03])
def piecewise_linear(x, x0, y0, k1, k2):
return np.piecewise(x,
[x < x0],
[lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0])
p , e = optimize.curve_fit(piecewise_linear, x, y)
xd = np.linspace(0, 15, 100)
plt.plot(x, y, "o")
plt.plot(xd, piecewise_linear(xd, *p))
我试图找出如何扩展它以处理n个断点。
我为piecewise_linear()方法尝试了以下代码来处理2个断点,但它不会以任何方式改变断点的值。
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03, 150, 152, 154, 156, 158])
def piecewise_linear(x, x0, x1, a1, b1, a2, b2, a3, b3):
return np.piecewise(x,
[x < x0, np.logical_and(x >= x0, x < x1), x >= x1 ],
[lambda x:a1*x + b1, lambda x:a2*x+b2, lambda x: a3*x + b3])
p , e = optimize.curve_fit(piecewise_linear, x, y)
xd = np.linspace(0, 20, 100)
plt.plot(x, y, "o")
plt.plot(xd, piecewise_linear(xd, *p))
非常感谢任何输入
答案 0 :(得分:6)
NumPy有一个polyfit
function,可以很容易地通过一组点找到最合适的线:
coefs = npoly.polyfit(xi, yi, 1)
所以真的唯一的困难就是找到断点。对于给定的一组 断点通过给定的数据找到最合适的线条是微不足道的。
因此,不要试图找到断点和系数的位置 在线性部分的同时,足以使参数空间最小化 断点。
由于断点可以通过整数索引值指定到x
数组中,
参数空间可以被认为是N
维的整数网格上的点,其中
N
是断点的数量。
optimize.curve_fit
不是这个问题的最小化器的好选择
因为参数空间是整数值。如果您使用curve_fit
,
算法会调整参数以确定在哪个方向上
移动。如果调整小于1个单位,则断点的x值不会
更改,因此错误不会更改,因此算法不会获得任何信息
关于移动参数的正确方向。因此curve_fit
当参数空间基本上是整数值时,它往往会失败。
一个更好但不是非常快的最小化器将是一个强力网格搜索。如果
断点的数量很小(x
的参数空间 - 值是
小)这可能就足够了。如果断点的数量很大和/或
参数空间很大,那么也许设置一个多级粗/细
(蛮力)网格搜索。或者,也许有人会建议使用比蛮力更聪明的最小化器......
import numpy as np
import numpy.polynomial.polynomial as npoly
from scipy import optimize
import matplotlib.pyplot as plt
np.random.seed(2017)
def f(breakpoints, x, y, fcache):
breakpoints = tuple(map(int, sorted(breakpoints)))
if breakpoints not in fcache:
total_error = 0
for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y):
total_error += ((f(xi) - yi)**2).sum()
fcache[breakpoints] = total_error
# print('{} --> {}'.format(breakpoints, fcache[breakpoints]))
return fcache[breakpoints]
def find_best_piecewise_polynomial(breakpoints, x, y):
breakpoints = tuple(map(int, sorted(breakpoints)))
xs = np.split(x, breakpoints)
ys = np.split(y, breakpoints)
result = []
for xi, yi in zip(xs, ys):
if len(xi) < 2: continue
coefs = npoly.polyfit(xi, yi, 1)
f = npoly.Polynomial(coefs)
result.append([f, xi, yi])
return result
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25,
126.14, 140.03, 150, 152, 154, 156, 158])
# Add some noise to make it exciting :)
y += np.random.random(len(y))*10
num_breakpoints = 2
breakpoints = optimize.brute(
f, [slice(1, len(x), 1)]*num_breakpoints, args=(x, y, {}), finish=None)
plt.scatter(x, y, c='blue', s=50)
for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y):
x_interval = np.array([xi.min(), xi.max()])
print('y = {:35s}, if x in [{}, {}]'.format(str(f), *x_interval))
plt.plot(x_interval, f(x_interval), 'ro-')
plt.show()
打印
y = poly([ 4.58801083 2.94476604]) , if x in [1.0, 6.0]
y = poly([-70.36472935 14.37305793]) , if x in [7.0, 15.0]
y = poly([ 123.24565235 1.94982153]), if x in [16.0, 20.0]
和情节