scikit-learn对于线性回归提供了比较多的类库,这些类库都可以用来做线性回归分析,本文就对这些类库的使用做一个总结,重点讲述这些线性回归算法库的不同和各自的使用场景。
线性回归的目的是要得到输出向量\(\mathbf{Y}\)和输入特征\(\mathbf{X}\)之间的线性关系,求出线性回归系数\(\mathbf\theta\),也就是 \(\mathbf{Y = X\theta}\)。其中\(\mathbf{Y}\)的维度为mx1,\(\mathbf{X}\)的维度为mxn,而\(\mathbf{\theta}\)的维度为nx1。m代表样本个数,n代表样本特征的维度。
为了得到线性回归系数\(\mathbf{\theta}\),我们需要定义一个损失函数,一个极小化损失函数的优化方法,以及一个验证算法的方法。损失函数的不同,损失函数的优化方法的不同,验证方法的不同,就形成了不同的线性回归算法。scikit-learn中的线性回归算法库可以从这这三点找出各自的不同点。理解了这些不同点,对不同的算法使用场景也就好理解了。
1. LinearRegression损失函数:
LinearRegression类就是我们平时说的最常见普通的线性回归,它的损失函数也是最简单的,如下:
\(J(\mathbf\theta) = \frac{1}{2}(\mathbf{X\theta} - \mathbf{Y})^T(\mathbf{X\theta} - \mathbf{Y})\)
损失函数的优化方法:
对于这个损失函数,一般有梯度下降法和最小二乘法两种极小化损失函数的优化方法,而scikit中的LinearRegression类用的是最小二乘法。通过最小二乘法,可以解出线性回归系数\(\mathbf\theta\)为:
\( \mathbf{\theta} = (\mathbf{X^{T}X})^{-1}\mathbf{X^{T}Y} \)
验证方法:
LinearRegression类并没有用到交叉验证之类的验证方法,需要我们自己把数据集分成训练集和测试集,然后训练优化。
使用场景:
一般来说,只要我们觉得数据有线性关系,LinearRegression类是我们的首先。如果发现拟合或者预测的不好,再考虑用其他的线性回归库。如果是学习线性回归,推荐先从这个类开始第一步的研究。
2. Ridge损失函数:
由于第一节的LinearRegression没有考虑过拟合的问题,有可能泛化能力较差,这时损失函数可以加入正则化项,如果加入的是L2范数的正则化项,这就是Ridge回归。损失函数如下:
\(J(\mathbf\theta) = \frac{1}{2}(\mathbf{X\theta} - \mathbf{Y})^T(\mathbf{X\theta} - \mathbf{Y}) + \frac{1}{2}\alpha||\theta||_2^2\)
其中\(\alpha\)为常数系数,需要进行调优。\(||\theta||_2\)为L2范数。
Ridge回归在不抛弃任何一个特征的情况下,缩小了回归系数,使得模型相对而言比较的稳定,不至于过拟合。
损失函数的优化方法:
对于这个损失函数,一般有梯度下降法和最小二乘法两种极小化损失函数的优化方法,而scikit中的Ridge类用的是最小二乘法。通过最小二乘法,可以解出线性回归系数\(\mathbf\theta\)为:
\(\mathbf{\theta = (X^TX + \alpha E)^{-1}X^TY}\)
其中E为单位矩阵。
验证方法:
Ridge类并没有用到交叉验证之类的验证方法,需要我们自己把数据集分成训练集和测试集,需要自己设置好超参数\(\alpha\)。然后训练优化。
使用场景:
一般来说,只要我们觉得数据有线性关系,用LinearRegression类拟合的不是特别好,需要正则化,可以考虑用Ridge类。但是这个类最大的缺点是每次我们要自己指定一个超参数\(\alpha\),然后自己评估\(\alpha\)的好坏,比较麻烦,一般我都用下一节讲到的RidgeCV类来跑Ridge回归,不推荐直接用这个Ridge类,除非你只是为了学习Ridge回归。
3. RidgeCVRidgeCV类的损失函数和损失函数的优化方法完全与Ridge类相同,区别在于验证方法。
验证方法:
RidgeCV类对超参数\(\alpha\)使用了交叉验证,来帮忙我们选择一个合适的\(\alpha\)。在初始化RidgeCV类时候,我们可以传一组备选的\(\alpha\)值,10个,100个都可以。RidgeCV类会帮我们选择一个合适的\(\alpha\)。免去了我们自己去一轮轮筛选\(\alpha\)的苦恼。
使用场景:
一般来说,只要我们觉得数据有线性关系,用LinearRegression类拟合的不是特别好,需要正则化,可以考虑用RidgeCV类。不是为了学习的话就不用Ridge类。为什么这里只是考虑用RidgeCV类呢?因为线性回归正则化有很多的变种,Ridge只是其中的一种。所以可能需要比选。如果输入特征的维度很高,而且是稀疏线性关系的话,RidgeCV类就不合适了。这时应该主要考虑下面几节要讲到的Lasso回归类家族。
4. Lasso损失函数: