机器学习之路: python线性回归 过拟合 L1与L2正则化

git:https://github.com/linyi0604/MachineLearning

正则化:
提高模型在未知数据上的泛化能力
避免参数过拟合
正则化常用的方法:
在目标函数上增加对参数的惩罚项
削减某一参数对结果的影响力度

L1正则化:lasso
在线性回归的目标函数后面加上L1范数向量惩罚项。

f = w * x^n + b + k * ||w||1

x为输入的样本特征
w为学习到的每个特征的参数
n为次数
b为偏置、截距
||w||1 为 特征参数的L1范数,作为惩罚向量
k 为惩罚的力度

L2范数正则化:ridge
在线性回归的目标函数后面加上L2范数向量惩罚项。

f = w * x^n + b + k * ||w||2

x为输入的样本特征
w为学习到的每个特征的参数
n为次数
b为偏置、截距
||w||2 为 特征参数的L2范数,作为惩罚向量
k 为惩罚的力度


下面模拟 根据蛋糕的直径大小 预测蛋糕价格
采用了4次线性模型,是一个过拟合的模型
分别使用两个正则化方法 进行学习和预测


1 from sklearn.linear_model import LinearRegression, Lasso, Ridge 2 # 导入多项式特征生成器 3 from sklearn.preprocessing import PolynomialFeatures 4 5 6 \'\'\' 7 正则化: 8 提高模型在未知数据上的泛化能力 9 避免参数过拟合 10 正则化常用的方法: 11 在目标函数上增加对参数的惩罚项 12 削减某一参数对结果的影响力度 13 14 L1正则化:lasso 15 在线性回归的目标函数后面加上L1范数向量惩罚项。 16 17 f = w * x^n + b + k * ||w||1 18 19 x为输入的样本特征 20 w为学习到的每个特征的参数 21 n为次数 22 b为偏置、截距 23 ||w||1 为 特征参数的L1范数,作为惩罚向量 24 k 为惩罚的力度 25 26 L2范数正则化:ridge 27 在线性回归的目标函数后面加上L2范数向量惩罚项。 28 29 f = w * x^n + b + k * ||w||2 30 31 x为输入的样本特征 32 w为学习到的每个特征的参数 33 n为次数 34 b为偏置、截距 35 ||w||2 为 特征参数的L2范数,作为惩罚向量 36 k 为惩罚的力度 37 38 39 下面模拟 根据蛋糕的直径大小 预测蛋糕价格 40 采用了4次线性模型,是一个过拟合的模型 41 分别使用两个正则化方法 进行学习和预测 42 43 \'\'\' 44 45 # 样本的训练数据,特征和目标值 46 x_train = [[6], [8], [10], [14], [18]] 47 y_train = [[7], [9], [13], [17.5], [18]] 48 # 准备测试数据 49 x_test = [[6], [8], [11], [16]] 50 y_test = [[8], [12], [15], [18]] 51 # 进行四次线性回归模型拟合 52 poly4 = PolynomialFeatures(degree=4) # 4次多项式特征生成器 53 x_train_poly4 = poly4.fit_transform(x_train) 54 # 建立模型预测 55 regressor_poly4 = LinearRegression() 56 regressor_poly4.fit(x_train_poly4, y_train) 57 x_test_poly4 = poly4.transform(x_test) 58 print("四次线性模型预测得分:", regressor_poly4.score(x_test_poly4, y_test)) # 0.8095880795746723 59 60 # 采用L1范数正则化线性模型进行学习和预测 61 lasso_poly4 = Lasso() 62 lasso_poly4.fit(x_train_poly4, y_train) 63 print("L1正则化的预测得分为:", lasso_poly4.score(x_test_poly4, y_test)) # 0.8388926873604382 64 65 # 采用L2范数正则化线性模型进行学习和预测 66 ridge_poly4 = Ridge() 67 ridge_poly4.fit(x_train_poly4, y_train) 68 print("L2正则化的预测得分为:", ridge_poly4.score(x_test_poly4, y_test)) # 0.8374201759366456

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zzfpff.html