Linear Regression
简介
线性回归是一种回归学习方法,一般用于处理连续性变量,算是机器学习的入门算法。虽然线性模型的形式很简单,但是线性模型的思想是很重要的,许多非线性模型都是在线性模型的基础上通过引入高维映射而得。
- 建模速度快,不需要复杂计算
- 可解释性好
- 不适用与非线性数据
- 可能出现过拟合
基本原理
给定数据集D={(x1,y1),...,(xm,ym},其中xi=(xi1,...,xid),线性回归模型试图学习到y^=wTx+b,使得y^近似等于y。
一般选用均方误差(mean square error, MSE),采用**最小二乘法(least square method)**求解,简单来说就是找到一条直线,使所有样本到直线上的欧氏距离之和最小。
均方误差即L=2m1Σi=1m(y^−y)2,这里乘了21是为了使后面的计算式更为简洁。
基本思路:首先赋予w、b初始值,用链式法则求出梯度,沿着梯度的反方向不断更新参数,使损失函数不断减小至收敛。具体求法为:
∂w∂L=∂y^∂L∂w∂y^=m1Σi=0m(y^i−yi)xi
∂b∂L=∂y^∂L∂b∂y^=m1Σi=0m(y^i−yi)
参数更新:
wj←wj+α(y−y^)xj
b←b+α(y−y^)
其中α称为学习率(learning rate)。
reference:
机器学习中的五种回归模型及其优缺点