Basic RNN、LSTM的前向传播和反向传播详细解析

Basic RNN、LSTM的前向传播和反向传播详细解析

Basic RNN、LSTM由于它们独特的架构,被大量应用在自然语言处理和序列模型的任务上。通过它们自身特殊的结构,可以记住之前的输入中的部分内容和信息,并对之后的输出产生影响。

  • 本文主要针对 :对RNN和LSTM有一定基础了解,但是对公式推导还不是完全掌握的童鞋(尤其是lstm的反向传播部分),欢迎各位批评指正~
  • 由于markdown编辑公式太麻烦了,所以公式也都是本地编辑之后的截图,有不正确的地方欢迎指正

Basic RNN架构简介

整体架构

模型的整体结构如下图所示,输入的是序列x、输出y,长度为Tx。当然还有针对输入输出不相等的RNN结构,这里只是为了详解RNN的公式推导,特别是反向传播的推导,所以不再赘述。
Basic RNN、LSTM的前向传播和反向传播详细解析

Figure 1: Basic RNN model

上图其实是沿时间轴展开的RNN模型,其实图中所有的RNN-cell都共用一套参数,每个cell输入当前时间点的输入x< t >和前一个cell输出的a< t-1 >,得到当前cell的输出a< t >和y< t >

BasicRNN 前向传播

  • 现在我们单独对每个cell进行公式推导,最终整个模型的公式其实就是单个cell的循环调用。
  • 下图是单个cell的具体结构图,以及前向传播的公式,非常的简洁明了
    Basic RNN、LSTM的前向传播和反向传播详细解析
    Figure 2: Basic RNN cell

BasicRNN 反向传播

针对前面介绍的每个cell前向传播图和公式,我们能很快的写出针对每个cell的反向传播公式:

Basic RNN、LSTM的前向传播和反向传播详细解析

Figure 3: Basic RNN BP

由前向传播的单个cell图,根据梯度反向传播易知。当前cell的Ja<t>由两部分构成:
- 当前cell的输出y^< t >与真实标签代入损失函数,通过损失函数对a< t >求导得到的梯度da< t >1
- 输入到下一个cell的a< t >传回的梯度da< t >2

公式推导前我们还需要知道tanh(x)x=1(tanh(x))2
下面给出BasicRNN的反向传播推导:
Basic RNN、LSTM的前向传播和反向传播详细解析

LSTM架构简介

模型的整体结构如下图所示,输入的是序列x、输出y,长度为Tx。同样的模型其实是在时间轴上的展开.
- 每个时间点输入当前x< t >、前一个时间的cell输出的a< t-1 >和c< t-1 >
- 输出y< t >、a< t >、c< t >

Basic RNN、LSTM的前向传播和反向传播详细解析

Figure 4: LSTM model

整个模型的和RNN十分相似,最大的区别就是改进了cell的结构,使得模型拥有了长期和短期记忆能力,更专业点的说法就是避免了梯度消失的问题。

LSTM前向传播

同样的我们将整个模型拆成一个cell区分析它的前向和反向传播过程,并且给出具体求导公式。
- 下面给出单个cell的前向传播示意图和具体传播公式:

Basic RNN、LSTM的前向传播和反向传播详细解析

Figure 5: LSTM cell

这里我们省略了每个cell输出y^< t >的计算公式,因为它和Basic RNN的公式完全一样。

LSTM反向传播

LSTM单个cell的反向传播比Basic RNN看起来要复杂很多,主要变化就是添加了三个门:遗忘门Γ< t >f、更新门Γ< t >u和输出门Γ< t >o。但是我们理清楚单个cell接收到的所有梯度,就很容易理解了。
- 首先我们可以看看图5的前向传播图,把箭头反向之后,其实就是我们的梯度反向传播图,这时我们可以在图中观察到两点
- 当前cell中a< t >通过反向传播得到的梯度同样有两个部分
- 当前输出y^< t >代入损失函数,对a< t >求导得到的da< t >1
- 输入到下一个cell的a< t >传回的梯度da< t >2
- 当前cell还要接受输入到下一个cell的c< t >传回的梯度dc< t >

公式推导前我们还需要知道:
- σ表示的是sigmoid**函数
- tanh(x)x=1(tanh(x))2
- σ(x)x=σ(x)(1σ(x))

下面给出LSTM的反向传播求导过程:
1)下面是当前cell通过反向传播得到的两部分梯度:
Basic RNN、LSTM的前向传播和反向传播详细解析
2)计算三个门和c~< t >的梯度:
Basic RNN、LSTM的前向传播和反向传播详细解析
- 公式(4)、(5)、(6)受步骤1)中的两部分梯度影响

3)由步骤2)得到的梯度,我们可以进一步计算出四个参数矩阵dWfdWudWcdWo的梯度:
Basic RNN、LSTM的前向传播和反向传播详细解析

4)dbfdbudbcdbo的求导对应步骤3),但是最后要把得到的矩阵按列求和:
Basic RNN、LSTM的前向传播和反向传播详细解析
5)由步骤1)2)所得,并且结合图五中的前向传播路径,可以进一步求出da< t-1 >dc< t-1 >dx< t >
Basic RNN、LSTM的前向传播和反向传播详细解析
- 公式(15)(17)由图5易知受三个门和c~< t >的反向传播梯度影响
- 公式(15)(17)其中的WfWfWfWf其实是不同的,(15)中的只取与a< t-1 > 对应的前半部本,(16)取的是与x< t >对应的后半部分


注:本博文图片和推导参考了Andrew的DeepLearning.ai课程中的RNN部分,但是课后习题给出的公式推导存在错误和前后数学符号不统一的情况,因而在其基础上写出本博文
如果公式中符号错误,欢迎批评指正