【机器学习】正则化的线性回归 —— 岭回归与Lasso回归 (2)

如果想象一种非常极端的情况:在参数的整个定义域上,第2项的取值都远远大于第一项的取值,那么最终的损失函数几乎100%都会由第2项决定,也就是整个代价函数的图像会非常类似于$J=|\theta_1|$(图0-2)而不是原来的MSE函数的图像(图0-1)。这时候就相当于$\lambda$的取值过大的情况,最终的全局最优解将会是坐标原点,这就是为什么在这种情况下最终得到的解全都为0.

 

1. 岭回归

岭回归与多项式回归唯一的不同在于代价函数上的差别。岭回归的代价函数如下:

$$J(\theta) = \frac{1}{m} \sum_{i=1}^{m}{(y^{(i)} - (w x^{(i)} + b))^2}  + \lambda ||w||_2^2 = MSE(\theta) + \lambda \sum_{i = 1}^{n}{\theta_i^2} \ \quad \cdots \ (1 - 1)$$

为了方便计算导数,通常也写成下面的形式:

$$J(\theta) = \frac{1}{2m} \sum_{i=1}^{m}{(y^{(i)} - (w x^{(i)} + b))^2}  + \frac{\lambda}{2} ||w||_2^2 = \frac{1}{2}MSE(\theta) + \frac{\lambda}{2} \sum_{i = 1}^{n}{\theta_i^2} \ \quad \cdots \ (1 - 2)$$

上式中的$w$是长度为$n$的向量,不包括截距项的系数$\theta_0$, $\theta$是长度为$n + 1$的向量,包括截距项的系数$\theta_0$,$m$为样本数,$n$为特征数.

 

岭回归的代价函数仍然是一个凸函数,因此可以利用梯度等于0的方式求得全局最优解(正规方程):

$$\theta = (X^T X + \lambda I)^{-1}(X^T y)$$

上述正规方程与一般线性回归的正规方程相比,多了一项$\lambda I$,其中$I$表示单位矩阵。假如$X^T X$是一个奇异矩阵(不满秩),添加这一项后可以保证该项可逆。由于单位矩阵的形状是对角线上为1其他地方都为0,看起来像一条山岭,因此而得名。

 

除了上述正规方程之外,还可以使用梯度下降的方式求解(求梯度的过程可以参考一般线性回归,3.2.2节)。这里采用式子$1 - 2$来求导:

$$\nabla_{\theta} J(\theta) = \frac{1}{m} X^T \cdot (X \cdot \theta - y)  + \lambda w \ \quad \cdots \ (1 - 3) $$

因为式子$1- 2$中和式第二项不包含$\theta_0$,因此求梯度后,上式第二项中的$w$本来也不包含$\theta_0$。为了计算方便,添加$\theta_0 = 0$到$w$.

因此在梯度下降的过程中,参数的更新可以表示成下面的公式:

$$\theta = \theta - (\frac{\alpha}{m} X^T \cdot (X \cdot \theta - y)  + \lambda w) \ \quad \cdots \ (1 - 4) $$

其中$\alpha$为学习率,$\lambda$为正则化项的参数

1.1 数据以及相关函数

1 import numpy as np 2 import matplotlib.pyplot as plt 3 from sklearn.preprocessing import PolynomialFeatures 4 from sklearn.metrics import mean_squared_error 5 6 data = np.array([[ -2.95507616, 10.94533252], 7 [ -0.44226119, 2.96705822], 8 [ -2.13294087, 6.57336839], 9 [ 1.84990823, 5.44244467], 10 [ 0.35139795, 2.83533936], 11 [ -1.77443098, 5.6800407 ], 12 [ -1.8657203 , 6.34470814], 13 [ 1.61526823, 4.77833358], 14 [ -2.38043687, 8.51887713], 15 [ -1.40513866, 4.18262786]]) 16 m = data.shape[0] # 样本大小 17 X = data[:, 0].reshape(-1, 1) # 将array转换成矩阵 18 y = data[:, 1].reshape(-1, 1)

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/wpzxsf.html