Julia中的坐标下降算法用于最小二乘不收敛

时间:2016-12-16 19:04:28

标签: optimization regression julia numerical-methods convergence

作为编写自己的弹性网络解算器的热身,我试图使用坐标下降来实现足够快的普通最小二乘法。

我相信我已经正确地实现了坐标下降算法,但是当我使用" fast"版本(见下文),算法非常不稳定,输出回归系数,当特征数量与样本数量相比时,通常会溢出64位浮点数。

线性回归和OLS

如果b = A * x,其中A是矩阵,x是未知回归系数的向量,y是输出,我想找到最小化的x

|| b - Ax || ^ 2

如果A [j]是A的第j列,而A [-j]是没有列j的A,并且A的列被归一化,所以对于所有j,|| A [j] || ^ 2 = 1 ,那么坐标更新是

坐标下降:

x[j]  <--  A[j]^T * (b - A[-j] * x[-j])

我跟随these notes (page 9-10),但推导是简单的微积分。

它指出,不是一直重新计算A [j] ^ T(b - A [-j] * x [-j]),更快的方法是用

快速坐标下降:

x[j]  <--  A[j]^T*r + x[j]

其中总残差r = b-Ax是在环路坐标之外计算的。这些更新规则的等效性来自注意到Ax = A [j] * x [j] + A [-j] * x [-j]并重新排列术语。

我的问题是,虽然第二种方法确实更快,但只要特征数量与样本数量相比不小,它就会在数字上非常不稳定。我想知道是否有人可能会了解为什么会出现这种情况。我应该注意到,第一种更稳定的方法仍然不同意更多的标准方法,因为特征数量接近样本数量。

朱莉亚代码

以下是两个更新规则的一些Julia代码:

function OLS_builtin(A,b)
    x = A\b
    return(x)
end

function OLS_coord_descent(A,b)    
    N,P = size(A)
    x = zeros(P)
    for cycle in 1:1000
        for j = 1:P 
            x[j] = dot(A[:,j], b - A[:,1:P .!= j]*x[1:P .!= j])
        end    
    end
    return(x)
end

function OLS_coord_descent_fast(A,b) 
    N,P = size(A)
    x = zeros(P)
    for cycle in 1:1000
        r = b - A*x
        for j = 1:P
            x[j] += dot(A[:,j],r)
        end    
    end
    return(x)
end

问题示例

我使用以下内容生成数据:

n = 100
p = 50
σ = 0.1
β_nz = float([i*(-1)^i for i in 1:10])

β = append!(β_nz,zeros(Float64,p-length(β_nz)))
X = randn(n,p); X .-= mean(X,1); X ./= sqrt(sum(abs2(X),1))
y = X*β + σ*randn(n); y .-= mean(y);

这里我使用p = 50,我在OLS_coord_descent(X,y)OLS_builtin(X,y)之间得到了很好的一致,而OLS_coord_descent_fast(X,y)则返回了回归系数的指数大值。

当p小于约20时,OLS_coord_descent_fast(X,y)与其他两个一致。

猜想

因为事情同意p&lt;&lt; n,我认为算法是正式的,但在数值上不稳定。有没有人对这个猜测是否正确有任何想法,如果有的话如何纠正不稳定性同时保留算法快速版本的(大部分)性能提升?

1 个答案:

答案 0 :(得分:5)

快速回答:您在每次r更新后忘记更新x[j]。以下是固定功能,其行为类似于OLS_coord_descent

function OLS_coord_descent_fast(A,b) 
    N,P = size(A)
    x = zeros(P)
    for cycle in 1:1000
        r = b - A*x
        for j = 1:P
            x[j] += dot(A[:,j],r)
            r -= A[:,j]*dot(A[:,j],r)   # Add this line
        end    
    end
    return(x)
end
相关问题