手撸机器学习算法 - 线性回归

系列文章目录:

感知机

线性回归

如果说感知机是最最最简单的分类算法,那么线性回归就是最最最简单的回归算法,所以这一篇我们就一起来快活的用两种姿势手撸线性回归吧;

算法介绍

线性回归通过超平面拟合数据点,经验误差一般使用MSE(均平方误差),优化方法为最小二乘法,算法如下:

假设输入数据为X,输出为Y,为了简单起见,这里的数据点为一维数据(更好可视化,处理方式没区别);

MSE公式为:\(\frac{1}{N}\sum_{i=1}^{N}(w*x_i+b-y_i)^2\)

最小二乘法:最小指的是目标是min,二乘指的就是MSE中误差的二次方,公式为:\(min\frac{1}{N}\sum_{i=1}^{N}(w*x_i+b-y_i)^2\)

由于目标是查找拟合最好的超平面,因此依然定义变量wb

对于w和b的求解有两种方式:

列出最小化的公式,利用优化求解器求解:

基于已知的X、Y,未知的w、b构建MSE公式;

定义最小化MSE的目标函数;

利用求解器直接求解上述函数得到新的w和b;

对经验误差函数求偏导并令其为0推导出wb的解析解:

基于最小化MSE的优化问题可以直接推导出w和b的计算方法;

基于推导出的计算方法直接计算求解;

利用求解器求解

利用求解器求解可以看作就是个列公式的过程,把已知的数据X和Y,未知的变量w和b定义好,构建出MSE的公式,然后丢到求解器直接对w和b求偏导即可,相对来说代码繁琐,但是过程更简单,没有任何数学推导;

代码实现 初始化数据集 X = np.array([1.51, 1.64, 1.6, 1.73, 1.82, 1.87]) y = np.array([1.63, 1.7, 1.71, 1.72, 1.76, 1.86]) 定义变量符号

所谓变量指的就是那些需要求解的部分,次数就是超平面的w和b;

w,b = symbols('w b',real=True) 定义经验误差函数MSE RDh = 0 for xi,yi in zip(X,y): RDh += (yi - (w*xi+b))**2 RDh = RDh / len(X) 定义求解函数

此处就是对w和b求偏导;

eRDHw = diff(RDh,w) eRDHb = diff(RDh,b) 求解w和b ans = solve((eRDHw,eRDHb),(w,b)) w,b = ans[w],ans[b] 运行结果

手撸机器学习算法 - 线性回归

完整代码 from sympy import symbols, diff, solve import numpy as np import matplotlib.pyplot as plt ''' 线性回归拟合wx+b直线; 最小二乘法指的是优化求解过程是通过对经验误差(此处是均平方误差)求偏导并令其为0以解的w和b; ''' # 数据集 D X为父亲身高,Y为儿子身高 X = np.array([1.51, 1.64, 1.6, 1.73, 1.82, 1.87]) y = np.array([1.63, 1.7, 1.71, 1.72, 1.76, 1.86]) # 构造符号 w,b = symbols('w b',real=True) # 定义经验误差计算公式:(1/N)*sum(yi-(w*xi+b))^2) RDh = 0 for xi,yi in zip(X,y): RDh += (yi - (w*xi+b))**2 RDh = RDh / len(X) # 对w和b求偏导:求偏导的结果是得到两个结果为0的方程式 eRDHw = diff(RDh,w) eRDHb = diff(RDh,b) # 求解联立方程组 ans = solve((eRDHw,eRDHb),(w,b)) w,b = ans[w],ans[b] print('使得经验误差RDh取得最小值的参数为:'+str(ans)) plt.scatter(X,y) x_range = [min(X)-0.1,max(X)+0.1] y_range = [w*x_range[0]+b,w*x_range[1]+b] plt.plot(x_range,y_range) plt.show() 推导公式求解

与利用优化器求解的区别在于针对\(min\frac{1}{N}\sum_{i=1}^{N}(w*x_i+b-y_i)^2\)\(w\)\(b\)求偏导并令其为0,并推导出wb的计算公式是自己推导的,还是由优化器完成的,事实上如果自己推导,那么最终代码实现上会非常简单(推导过程不会出现在代码中);

w和b的求解公式推导

首先,我们的优化目标为:

\[min \frac{1}{N}\sum_{i=1}^{N}(w*x_i+b-y_i)^2 \]

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zzfjyd.html