【公众号文章】——反向传播算法

链接:
https://mp.weixin.qq.com/s/YBOHt1WKuA1-NaqRtbimxA
原文:
http://www.cnblogs.com/pinard/p/6422831.html


1、DNNs反向传播算法要解决的问题

在了解DNNs的反向传播算法前,我们先要知道DNNs反向传播算法要解决的问题,也就是说,什么时候我们需要这个反向传播算法?

回到监督学习的一般问题中,假设我们有m个训练样本:
(x1,y1),(x2,y2),...,(xm,ym){(x_1, y_1), (x_2, y_2), ..., (x_m, y_m) }
其中x为输入向量,特征维度为n_in,而y为输出向量,特征维度为n_out。我们需要利用这m个样本训练出一个模型,当有一个新的测试样本(X_test, ?)来到时, 我们可以预测(Y_test, ?)向量的输出。

如果我们采用DNNs的模型,算上输入层n_in和输出层n_out的神经元,再加上若干隐藏层的神经元。此时我们需要找到适合所有隐藏层和输出层对应的线性系数矩阵W,偏倚向量b,让所有的训练样本输入计算出的输出尽可能的等于或很接近样本输出。怎么找到合适的参数呢?

如果大家对传统的机器学习的算法优化过程熟悉的话,这里就很容易联想到我们可以用一个合适的损失函数来度量训练样本的输出损失,接着对这个损失函数进行优化求最小化的极值,对应的一系列线性系数矩阵W,偏倚向量b即为我们的最终结果。

在DNNs中,损失函数优化极值求解的过程最常见的一般是通过梯度下降法来一步步迭代完成的,当然也可以是其他的迭代方法比如牛顿法与拟牛顿法。如果大家对梯度下降法不熟悉,建议阅读有关梯度下降(Gradient Descent)的相关知识。

总而言之,对DNNs的损失函数用梯度下降法进行迭代优化求极小值的过程即为我们的反向传播算法。


2、DNN反向传播算法的基本思路

在进行DNNs反响传播算法前,我们需要选择一个损失函数来度量训练样本计算出的输出和真实的训练样本输出之间的损失。那么,训练样本计算出的输出是怎么得来的呢?

这个输出是随机选择一系列W和b,是通过前向传播算法计算出来的。即通过一系列的计算如下,计算到输出层第L层对应的前向传播算法对应输出:
at=δ(zt)=δ(Wtat1+bl)a^t = δ(z^t) = δ(W^ta^{t-1}+b^l)

回到损失函数,DNNs可选择的损失函数有不少,为了专注算法,这里使用最常见的均方差来度量损失。即对于每一个样本,期望最小化下列式子:
J(W,b,x,y)=12aLy22J(W,b,x,y) = \frac12||a^L-y||^2_2其中y和aLa^L为特征维度的n_out向量,而||S||2为2的L2范数。

损失函数有了,现在我们开始用梯度下降法迭代求解每一层的W和b

首先是输出层L,注意到W和b满足以下式子:
aL=δ(aL)=δ(WLaL1+bL)a^L = δ(a^L) = δ(W^La^{L-1}+b^L)这样对于输出层的参数,我们损失函数变为:J(W,b,x,y)=12aLy22=12δ(WLaL1+bL)y22J(W,b,x,y) = \frac12||a^L-y||^2_2 = \frac12||δ(W^La^{L-1}+b^L)-y||^2_2

这样求解W,b的梯度就显得简单了许多,计算过程略。


3、DNNs反向传播算法的过程

由于梯度下降法有批量(Batch),小批量(mini-Batch),随机三个变种,为了简化描述,这里我们以最基本的批量梯度下降法为例来描述反向传播算法。实际上在业界使用最多的是mini-Batch的梯度下降法。不过区别仅仅在于迭代时训练样本的选择而已。
输入: 总层数L,以及各隐藏层与输出层的神经元个数,**函数,损失函数,迭代步长α,最大迭代次数MAX与停止迭代阈值ϵ,输入的m个训练样本(x1,y1),(x2,y2),...,(xm,ym)(x1,y1),(x2,y2),...,(xm,ym){(x_1, y_1), (x_2, y_2), ..., (x_m, y_m) }{(x_1, y_1), (x_2, y_2), ..., (x_m, y_m) }
输出:各隐藏层与输出层的线性关系系数矩阵W和偏倚向量b
①初始化各隐藏层与输出层的线性关系系数矩阵W和偏倚向量b的值为一个随机值。
②for iter to 1 to MAX:


(1)for i =1 to m:
【公众号文章】——反向传播算法
(2)for l = 2 to L, 更新第l层的W^l, b^l
【公众号文章】——反向传播算法
(3)如果所有W,b的变化值都小于停止迭代阈值ϵ,则跳出迭代循环到步骤③


③输出各隐藏层于输出层的先行关系系数矩阵W和偏倚向量b。