理解BPTT及RNN的梯度消失与梯度爆炸

前言

上篇文章RNN详解已经介绍了RNN的结构和前向传播的计算公式,这篇文章讲一下RNN的反向传播算法BPTT,及RNN梯度消失和梯度爆炸的原因。


BPTT

RNN的反向传播,也称为基于时间的反向传播算法BPTT(back propagation through time)。对所有参数求损失函数的偏导,并不断调整这些参数使得损失函数变得尽可能小。

先贴出RNN的结构图以供观赏,下面讲的都是图中的单层单向RNN:
理解BPTT及RNN的梯度消失与梯度爆炸

图片来自:Recurrent Neural Networks Tutorial, Part 1 – Introduction to RNNs

反向传播做的是:算出参数梯度更新参数。参数梯度=损失函数对参数的偏导数Lw\frac{\partial L}{\partial w},参数更新公式已知:w=wαLww=w-\alpha \frac{\partial L}{\partial w}α\alpha是学习率。所以我们只需要求出Lw\frac{\partial L}{\partial w}就可以了。

分两部分:1.定义损失函数 LL2.再求出损失函数对参数的偏导数 Lw\frac{\partial L}{\partial w} ,下面开始啦,以参数 ww 为例:

1.定义损失函数:

假设时刻 tt 的损失函数为:Lt=12(Y3O3)2L_t=\frac{1}{2}(Y_3-O_3)^2

损失函数是均方差也好交叉熵也好都无所谓,这里只是举个例子,假定它是均方差。

因为有多个时刻,所以总损失函数为所有时刻的损失函数之和
L=t=0TLt              (1)L = \sum_{t=0}^{T}L_t \ \ \ \ \ \ \ \ \ \ \ \ \ \ (1)

第一步完成,损失函数get!

2.求损失函数对参数的偏导数:

ww在每一时刻都出现了,所以 ww 在时刻 tt 的梯度=时刻 tt 的损失函数对所有时刻ww 的梯度
Ltw=s=0TLtws     (2)\frac{\partial L_t}{\partial w}=\sum_{s=0}^{T}\frac{\partial L_t}{\partial w_s}\ \ \ \ \ (2)

将(2)代入(1)可得下面的结果,ww 的总梯度 Lw\frac{\partial L}{\partial w}等于ww所有时刻的梯度和:
Lw=t=0TLtw=t=0Ts=0TLtws \begin{aligned} \frac{\partial L}{\partial w} &=\sum_{t=0}^{T}\frac{\partial L_t}{\partial w}\\ &=\sum_{t=0}^{T}\sum_{s=0}^{T}\frac{\partial L_t}{\partial w_s}\\ \end{aligned}

第二步完成,梯度get!

3.更新参数

有了梯度就可以更新参数了:w=wαLww=w-\alpha \frac{\partial L}{\partial w}

以上三步就是针对参数ww的一次反向传播,个人认为了解这些就可以了,但是如果你想看BPTT更详尽的数学公式推导,我建议你看一下这篇文章


梯度消失

下面是RNN前向传播公式,其中ff一般是softmaxsoftmax函数,gg一般是tanhtanh函数。
st=f(Uxt+Wst1+b)ot=g(Vst+c) \begin{aligned} s_t &=f(Ux_t + Ws_{t-1}+b)\\ o_t &=g(Vs_t+c)\\ \end{aligned}

假设我们有一个RNN,时间序列只有三个时刻,下面是其结构图:
理解BPTT及RNN的梯度消失与梯度爆炸

前向传播:
s1=f(Ux1+Ws0+b)      ot=g(Vs1+c)s2=f(Ux2+Ws1+b)      ot=g(Vs2+c)s3=f(Ux3+Ws2+b)      ot=g(Vs3+c) s_1 =f(Ux_1 + Ws_0+b) \ \ \ \ \ \ o_t =g(Vs1+c)\\ s_2 =f(Ux_2 + Ws_1+b) \ \ \ \ \ \ o_t =g(Vs2+c)\\ s_3 =f(Ux_3 + Ws_2+b) \ \ \ \ \ \ o_t =g(Vs3+c)\\

反向传播:
我们现在只对t=3时刻的UVWU、V、W求损失函数L3L_3的偏导(其他时刻类似):
L3U=L3O3O3S3S3U+L3O3O3S3S3S2S2U+L3O3O3S3S3S2S2S1S1UL3V=L3S3S3VL3W=L3O3O3S3S3W+L3O3O3S3S3S2S2W+L3O3O3S3S3S2S2S1S1W \begin{aligned} \frac{\partial L_3}{\partial U} &= \frac{\partial L_3}{\partial O_3}\frac{\partial O_3}{\partial S_3}\frac{\partial S_3}{\partial U}+\frac{\partial L_3}{\partial O_3}\frac{\partial O_3}{\partial S_3}{\color{Red}{\frac{\partial S_3}{\partial S_2}}}\frac{\partial S_2}{\partial U}+\frac{\partial L_3}{\partial O_3}\frac{\partial O_3}{\partial S_3}{\color{Red}{\frac{\partial S_3}{\partial S_2}\frac{\partial S_2}{\partial S_1}}}\frac{\partial S_1}{\partial U}\\ \frac{\partial L_3}{\partial V} &= \frac{\partial L_3}{\partial S_3}\frac{\partial S_3}{\partial V}\\ \frac{\partial L_3}{\partial W} &= \frac{\partial L_3}{\partial O_3}\frac{\partial O_3}{\partial S_3}\frac{\partial S_3}{\partial W}+\frac{\partial L_3}{\partial O_3}\frac{\partial O_3}{\partial S_3}{\color{Red}{\frac{\partial S_3}{\partial S_2}}}\frac{\partial S_2}{\partial W}+\frac{\partial L_3}{\partial O_3}\frac{\partial O_3}{\partial S_3}{\color{Red}{\frac{\partial S_3}{\partial S_2}\frac{\partial S_2}{\partial S_1}}}\frac{\partial S_1}{\partial W} \end{aligned}

t=3t=3 时刻加上之前的时刻,一共是3,等于上下两个式子的加数。若 tt 足够大,则式中的加数就会很多,红色部分的项数也越多。

根据上述公式,我们可以得出任意时刻t时对UWU、W求偏导得公式,以WW为例:
LtW=k=0tLtOtOtSt(j=k+1tSjSj1)SjW        (3)\frac{\partial L_t}{\partial W}=\sum_{k=0}^t\frac{\partial L_t}{\partial O_t}\frac{\partial O_t}{\partial S_t}({\color{Red}{\prod_{j=k+1}^t\frac{\partial S_j}{\partial S_{j-1}}}})\frac{\partial S_j}{\partial W} \ \ \ \ \ \ \ \ (3)

SjS_j展开:sj=tanh(Uxj+Wsj1+b)s_j =tanh(Ux_j + Ws_{j-1}+b)

则式(3)中的j=k+1tSjSj1{\color{Red}{\prod_{j=k+1}^t\frac{\partial S_j}{\partial S_{j-1}}}}就变成了:j=k+1ttanhW{\color{Red}{\prod_{j=k+1}^ttanh'W}}

**函数tanh和它的导数图像如下
理解BPTT及RNN的梯度消失与梯度爆炸

图片来自Recurrent Neural Network系列3–理解RNN的BPTT算法和梯度消失

可以看出tanh1tanh'\leq 1,训练过程中几乎都是小于1的,而W的值一般会处于0~1之间,当时间序列足够长,即t足够大时,j=k+1ttanhW{\color{Red}{\prod_{j=k+1}^ttanh'W}}就会趋近于0,这就造成了梯度消失;当W的值很大(一般为初始化不当引起)时,j=k+1ttanhW{\color{Red}{\prod_{j=k+1}^ttanh'W}}就会趋近于无穷,这就造成了梯度爆炸。

RNN梯度消失问题很常见,题都爆炸问题一般不常见。RNN中的梯度消失会造成什么后果呢?会使RNN的长时记忆失效,简而言之就是会忘记很久之前的信息,记性不好。

至于怎么避免RNN的梯度消失和梯度都爆炸,我们现在知道,造成这种现象的根本原因就在于j=k+1tSjSj1{\color{Red}{\prod_{j=k+1}^t\frac{\partial S_j}{\partial S_{j-1}}}}这个连乘式,我们可以使这个连乘式中每项的偏导 SjSj10\frac{\partial S_j}{\partial S_{j-1}}\approx 0SjSj11\frac{\partial S_j}{\partial S_{j-1}}\approx 1。这就是LSTM做的事情。下一篇文章介绍LSTM。

references

[1] 零基础入门深度学习(5) - 循环神经网络
[2] Recurrent Neural Networks Tutorial, Part 1 – Introduction to RNNs
[3] Recurrent Neural Network系列3–理解RNN的BPTT算法和梯度消失
[4] On the difficulty of training recurrent neural networks