Linear regression with a forced zero intercept

Time: 2018-01-01 15:45:58

Tags: python numpy linear-regression

I compute the coefficients with the following least-squares method:

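import numpy as np

# Estimate the coefficients of the linear equation y = a + b*x
def calc_coefficients(_x, _y):
    # Sample means of x, y, x*y, x**2 and y**2 (_x and _y must be NumPy arrays)
    x, y = np.mean(_x), np.mean(_y)
    xy = np.mean(_x*_y)
    x2, y2 = np.mean(_x**2), np.mean(_y**2)
    n = len(_x)

    b = (xy - x*y) / (x2 - x**2)    # slope
    a = y - b*x                     # intercept
    # Uncertainty estimates for the slope and the intercept
    sig_b = np.sqrt((y2-y**2)/(x2-x**2)-b**2) / np.sqrt(n)
    sig_a = sig_b * np.sqrt(x2 - x**2)

    return a, b, sig_a, sig_b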

Sample data:

_x = np.array([0.009412743,0.014965211,0.013263312,0.013529132,0.009989368,0.013932615,0.020849682,0.010953529,0.003608903,0.007220992,0.012750529,0.021608436,0.031742052,0.022482958,0.021137599,0.018703295,0.021633681,0.019866029,0.020260629,0.034433715,0.009241074,0.012027059])

_y = np.array([0.294158677,0.359935335,0.313484808,0.301917271,0.169190763,0.486254864,0.305846328,0.347077387,0.188928817,0.422194367,0.41157232,0.39281496,0.497935681,0.34763333,0.281712023,0.352045535,0.339958296,0.395932086,0.359905526,0.450004349,0.395200865,0.365162443])

However, I need a (the y-intercept) to be zero, i.e. y = bx. I tried using:

np.linalg.lstsq(_x, _y)

But I got this error:

LinAlgError: 1-dimensional array given. Array must be two-dimensional

What is the best way to fit y = bx to the data?

1 answer:

Answer 0 (score: 3)

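The error occurs because you passed a one-dimensional array where lstsq expects a two-dimensional array of shape (n, 1), that is, a matrix with a single column. You could do x.reshape(-1, 1), but here is a way to do a least-squares fit with an arbitrary set of powers of x: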

import numpy as np
x = np.array([0, 1, 2, 3, 4, 5])
y = np.array([3, 6, 5, 7, 9, 1])
degrees = [1]       # list of degrees of x to use
matrix = np.stack([x**d for d in degrees], axis=-1)   # stack them like columns
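# lstsq also returns residuals, rank and singular values, which we ignore here;
# rcond=None selects the current default cutoff and avoids a FutureWarning on NumPy 1.x
coeff = np.linalg.lstsq(matrix, y, rcond=None)[0]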
print("Coefficients", coeff)
fit = np.dot(matrix, coeff)
print("Fitted curve/line", fit)
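
The x.reshape(-1, 1) route mentioned above can be sketched on the question's sample data roughly as follows. The names x, y, X and b are chosen here for illustration, and the data is assumed to be stored as plain 1-D NumPy arrays:

```python
import numpy as np

# Sample data from the question, assumed to be 1-D NumPy arrays.
x = np.array([0.009412743, 0.014965211, 0.013263312, 0.013529132, 0.009989368,
              0.013932615, 0.020849682, 0.010953529, 0.003608903, 0.007220992,
              0.012750529, 0.021608436, 0.031742052, 0.022482958, 0.021137599,
              0.018703295, 0.021633681, 0.019866029, 0.020260629, 0.034433715,
              0.009241074, 0.012027059])
y = np.array([0.294158677, 0.359935335, 0.313484808, 0.301917271, 0.169190763,
              0.486254864, 0.305846328, 0.347077387, 0.188928817, 0.422194367,
              0.41157232, 0.39281496, 0.497935681, 0.34763333, 0.281712023,
              0.352045535, 0.339958296, 0.395932086, 0.359905526, 0.450004349,
              0.395200865, 0.365162443])

# lstsq needs a 2-D design matrix; reshape(-1, 1) turns x into one column of shape (n, 1).
X = x.reshape(-1, 1)

# The solution vector has one entry per column, so the slope is its only element.
b = np.linalg.lstsq(X, y, rcond=None)[0][0]
print("Slope b for y = b*x:", b)
```

With a single column there is no constant term in the model, so the fitted line necessarily passes through the origin.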

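The matrix passed to lstsq should contain columns of the form f(x), where f runs over the terms allowed in the model. So for a general linear model you would have x**0 and x**1. With a forced zero intercept, it is just x**1. In general, the columns do not have to be powers of x at all.

Output for degrees = [1], model y = bx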
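
As an illustration of columns that are not powers of x, here is a minimal sketch that fits the hypothetical model y = c0*x + c1*sin(x). The data is synthetic and noise-free, made up purely for this example, so the fit should recover the coefficients used to generate it:

```python
import numpy as np

# Synthetic, noise-free data generated from y = 2*x + 3*sin(x) (illustrative only).
x = np.linspace(0.5, 5.0, 10)
y = 2.0 * x + 3.0 * np.sin(x)

# One column per model term f(x); here the terms are x and sin(x), with no constant column.
matrix = np.stack([x, np.sin(x)], axis=-1)

coeff = np.linalg.lstsq(matrix, y, rcond=None)[0]
print("Coefficients", coeff)
```

The model stays linear in its coefficients even though sin(x) is not linear in x, which is all that lstsq requires.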

Coefficients [ 1.41818182]
Fitted curve/line [ 0.          1.41818182  2.83636364  4.25454545  5.67272727  7.09090909]

Output for degrees = [0, 1], model y = a + bx

Coefficients [ 5.0952381   0.02857143]
Fitted curve/line [ 5.0952381   5.12380952  5.15238095  5.18095238  5.20952381  5.23809524]
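
For the single-term model y = bx, the least-squares slope also has the closed form b = sum(x*y) / sum(x**2), which avoids lstsq altogether. The sketch below uses a hypothetical helper, fit_through_origin, and adds the textbook standard error of the slope for regression through the origin (residual variance with n - 1 degrees of freedom). That error formula is an assumption of this sketch and is not derived from the sig_b expression in the question:

```python
import numpy as np

def fit_through_origin(x, y):
    """Return the least-squares slope b for y = b*x and its standard error.

    Hypothetical helper for illustration; not part of NumPy.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    sxx = np.sum(x * x)
    b = np.sum(x * y) / sxx              # closed-form slope
    residuals = y - b * x
    # Only one parameter is fitted, so the residual variance uses n - 1 degrees of freedom.
    s2 = np.sum(residuals ** 2) / (len(x) - 1)
    sig_b = np.sqrt(s2 / sxx)            # standard error of the slope
    return b, sig_b

# Same toy data as in the answer above.
x = np.array([0, 1, 2, 3, 4, 5])
y = np.array([3, 6, 5, 7, 9, 1])
b, sig_b = fit_through_origin(x, y)
print("Slope", b, "standard error", sig_b)
```

On this toy data the slope equals 78/55, matching the 1.41818182 reported for degrees = [1].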