线性回归损失函数求解 (2)

这样的话,我们记\(J(w_1, w_2, b)\)为一个函数,求一个多元函数的最值我们在微积分中学到过就是求\(\frac{\partial J}{\partial w_1}, \frac{\partial J}{\partial w_2}, \frac{\partial J}{\partial b}\),并且令它们都等于0,就能求出最终的解了。

这里已经涉及到矩阵微积分的内容,我试着写几步:\[ \begin{equation} J(w_1, w_2, b) = (price-y)^{\mathrm{T}}(price-y) \end{equation} \]
\(price\)\(y\)都是向量,再将\(price\)用参数\(w_1,w_2,b\)表示:\[ \begin{equation} J(w) = (Xw-y)^{\mathrm{T}}(Xw-y) \end{equation} \]
(8)式中,\(X\)的每一行是1组数据,它是一个nx3的矩阵;\(w\)是个向量\[ X=\left[ \begin{matrix} 第一笔数据的 \ x1 & x2 & 1 \\ 第二笔数据的 \ x1 & x2 & 1 \\ . \\ . \\ . \\ 第n笔数据的 \ x1 & x2 & 1 \\ \end{matrix} \right] \ \ \ \ \ \ \ \ w =\left[ \begin{matrix} w_1\\ w_2\\ b \end{matrix} \right] \]
继续将(8)式化简\[ \begin{equation} J(w) = (w^{\mathrm{T}}X^{\mathrm{T}}-y^{\mathrm{T}})(Xw-y) \end{equation} \]
接着去括号\[ \begin{equation} J(w) = w^{\mathrm{T}}X^{\mathrm{T}}Xw-y^{\mathrm{T}}Xw-w^{\mathrm{T}}X^{\mathrm{T}}y+y^{\mathrm{T}}y \end{equation} \]
其中,\(y^{\mathrm{T}}Xw\)\(w^{\mathrm{T}}X^{\mathrm{T}}y\)是相等的,都是一个数,所以最终可以写为\[ \begin{equation} J(w) = w^{\mathrm{T}}X^{\mathrm{T}}Xw-2w^{\mathrm{T}}X^{\mathrm{T}}y+y^{\mathrm{T}}y \end{equation} \]
下面就要进行矩阵微积分了,讲实话我不会。但是我学会两个trick能求出最终的\(w\)

第一个trick来自台大的林轩田老师,我记得他很轻松地说可以把上面这个等式变换成我们会的一元二次等式,我当时带着满腹的怀疑按照他说的做了,不过真的得到了结果(惊吓!可能这就是数学的魅力)。我们将(11)式变为\[ \begin{equation} J(x) = w^{\mathrm{T}}Aw - 2w^{\mathrm{T}}b + c \\ subject \ to \ A = X^{\mathrm{T}}X \\ \ \ \ \ \ \ \ \ \ \ \ \ b=X^{\mathrm{T}}y\\ \ \ \ \ \ \ \ \ \ \ \ \ c=y^{\mathrm{T}}y \end{equation} \]
当然,这不是严格意义上的转换,但是真的能让我们像解熟悉的一元二次方程一样求出解。对(12)求导令其为0,再将原来的值代入回去能得到\[ \begin{equation} 2X^{\mathrm{T}}Xw - 2X^{\mathrm{T}}y = 0 \end{equation} \]
最终\[ \begin{equation} w = (X^{\mathrm{T}}X)^{-1}X^{\mathrm{T}}y \end{equation} \]

第二种求解的办法就是记住矩阵微积分的公式:

y \(\frac{\partial y}{\partial X}\)
\(AX\)   \(A^{\mathrm{T}}\)  
\(X^{\mathrm{T}}A\)   \(A\)  
\(X^{\mathrm{T}}X\)   \(2X\)  
\(X^{\mathrm{T}}AX\)   \(AX+A^{\mathrm{T}}X\)  

等等,(14)式好熟悉。这不就是求解线性方程组\(Ax=b\)这个方程组无解时的最优近似解么。所以,机器学习的线性回归其实就是最小二乘中的拟合问题。一开始就将这个问题看为求解线性方程组问题的话:\[ \left[ \begin{matrix} 第一笔数据的 \ x1 & x2 & 1 \\ 第二笔数据的 \ x1 & x2 & 1 \\ . \\ . \\ . \\ 第n笔数据的 \ x1 & x2 & 1 \\ \end{matrix} \right] \left[ \begin{matrix} w_1\\ w_2\\ b \end{matrix} \right]=\left[ \begin{matrix} 第一笔数据的 \ price \\ 第二笔数据的 \ price \\ . \\ . \\ . \\ 第n笔数据的 \ price \\ \end{matrix} \right] \]
不就是求这个方程组有没有解么?如果没有解,我们就求近似解。这个近似解的求解方法就是上一篇笔记中一直强调的部分,在等式左右两边左乘矩阵的转置,我们马上能得到近似解。

画图代码 import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from matplotlib import cm x1 = np.linspace(-5, 5, 5) x2 = x1 x1, x2 = np.meshgrid(x1, x2) price = x1 * 3 + x2 * 4 - 5 np.random.seed(325) data_x = np.random.randint(-5, 5, 5) data_y = np.random.randint(-5, 5, 5) data_z = data_x * 3 + data_y * 4 - 5 bias = np.array([5, 2, -3, 4, -3]) data_z = data_z + bias fig = plt.figure() ax = fig.gca(projection='3d') ax.plot_wireframe(x1, x2, price, rstride=10, cstride=10) for i in range(len(data_x)): ax.scatter(data_x[i], data_y[i], data_z[i], color='r') ax.set_xlabel('x1') ax.set_ylabel('x2') ax.set_zlabel('price') ax.set_xticks([-5, 0, 5]) ax.set_yticks([-5, 0,10]) ax.set_zticks([ -40, 0, 40]) plt.show()

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/wpjgjs.html