前言
上篇文章RNN详解已经介绍了RNN的结构和前向传播的计算公式,这篇文章讲一下RNN的反向传播算法BPTT,及RNN梯度消失和梯度爆炸的原因。
BPTT
RNN的反向传播,也称为基于时间的反向传播算法BPTT(back propagation through time)。对所有参数求损失函数的偏导,并不断调整这些参数使得损失函数变得尽可能小。
先贴出RNN的结构图以供观赏,下面讲的都是图中的单层单向RNN:
![理解BPTT及RNN的梯度消失与梯度爆炸 理解BPTT及RNN的梯度消失与梯度爆炸](/default/index/img?u=aHR0cHM6Ly9waWFuc2hlbi5jb20vaW1hZ2VzLzE5MS9kYTBiNGM3ZmJjYjJmZGUwYzIxMWFkNjBhOTg1MjU1Ny5KUEVH)
图片来自:Recurrent Neural Networks Tutorial, Part 1 – Introduction to RNNs
反向传播做的是:算出参数梯度并更新参数。参数梯度=损失函数对参数的偏导数∂w∂L,参数更新公式已知:w=w−α∂w∂L,α是学习率。所以我们只需要求出∂w∂L就可以了。
分两部分:1.定义损失函数 L,2.再求出损失函数对参数的偏导数 ∂w∂L ,下面开始啦,以参数 w 为例:
1.定义损失函数:
假设时刻 t 的损失函数为:Lt=21(Y3−O3)2
损失函数是均方差也好交叉熵也好都无所谓,这里只是举个例子,假定它是均方差。
因为有多个时刻,所以总损失函数为所有时刻的损失函数之和:
L=t=0∑TLt (1)
第一步完成,损失函数get!
2.求损失函数对参数的偏导数:
w在每一时刻都出现了,所以 w 在时刻 t 的梯度=时刻 t 的损失函数对所有时刻的 w 的梯度和:
∂w∂Lt=s=0∑T∂ws∂Lt (2)
将(2)代入(1)可得下面的结果,w 的总梯度 ∂w∂L等于w在所有时刻的梯度和:
∂w∂L=t=0∑T∂w∂Lt=t=0∑Ts=0∑T∂ws∂Lt
第二步完成,梯度get!
3.更新参数
有了梯度就可以更新参数了:w=w−α∂w∂L
以上三步就是针对参数w的一次反向传播,个人认为了解这些就可以了,但是如果你想看BPTT更详尽的数学公式推导,我建议你看一下这篇文章。
梯度消失
下面是RNN前向传播公式,其中f一般是softmax函数,g一般是tanh函数。
stot=f(Uxt+Wst−1+b)=g(Vst+c)
假设我们有一个RNN,时间序列只有三个时刻,下面是其结构图:
![理解BPTT及RNN的梯度消失与梯度爆炸 理解BPTT及RNN的梯度消失与梯度爆炸](/default/index/img?u=aHR0cHM6Ly9waWFuc2hlbi5jb20vaW1hZ2VzLzM1OC8zNmVmNWQ2MDBkNzQ5Y2RmYjc0NTk0NTIwZDI5ZGViZS5KUEVH)
前向传播:
s1=f(Ux1+Ws0+b) ot=g(Vs1+c)s2=f(Ux2+Ws1+b) ot=g(Vs2+c)s3=f(Ux3+Ws2+b) ot=g(Vs3+c)
反向传播:
我们现在只对t=3时刻的U、V、W求损失函数L3的偏导(其他时刻类似):
∂U∂L3∂V∂L3∂W∂L3=∂O3∂L3∂S3∂O3∂U∂S3+∂O3∂L3∂S3∂O3∂S2∂S3∂U∂S2+∂O3∂L3∂S3∂O3∂S2∂S3∂S1∂S2∂U∂S1=∂S3∂L3∂V∂S3=∂O3∂L3∂S3∂O3∂W∂S3+∂O3∂L3∂S3∂O3∂S2∂S3∂W∂S2+∂O3∂L3∂S3∂O3∂S2∂S3∂S1∂S2∂W∂S1
t=3 时刻加上之前的时刻,一共是3,等于上下两个式子的加数。若 t 足够大,则式中的加数就会很多,红色部分的项数也越多。
根据上述公式,我们可以得出任意时刻t时对U、W求偏导得公式,以W为例:
∂W∂Lt=k=0∑t∂Ot∂Lt∂St∂Ot(j=k+1∏t∂Sj−1∂Sj)∂W∂Sj (3)
把Sj展开:sj=tanh(Uxj+Wsj−1+b)
则式(3)中的∏j=k+1t∂Sj−1∂Sj就变成了:∏j=k+1ttanh′W
**函数tanh和它的导数图像如下
![理解BPTT及RNN的梯度消失与梯度爆炸 理解BPTT及RNN的梯度消失与梯度爆炸](/default/index/img?u=aHR0cHM6Ly9waWFuc2hlbi5jb20vaW1hZ2VzLzk2Mi8zYWI4MGM2YTE5Njk4ZjA0YWRiY2NlMDY1NzQyYzAwMi5KUEVH)
图片来自Recurrent Neural Network系列3–理解RNN的BPTT算法和梯度消失
可以看出tanh′≤1,训练过程中几乎都是小于1的,而W的值一般会处于0~1之间,当时间序列足够长,即t足够大时,∏j=k+1ttanh′W就会趋近于0,这就造成了梯度消失;当W的值很大(一般为初始化不当引起)时,∏j=k+1ttanh′W就会趋近于无穷,这就造成了梯度爆炸。
RNN梯度消失问题很常见,题都爆炸问题一般不常见。RNN中的梯度消失会造成什么后果呢?会使RNN的长时记忆失效,简而言之就是会忘记很久之前的信息,记性不好。
至于怎么避免RNN的梯度消失和梯度都爆炸,我们现在知道,造成这种现象的根本原因就在于∏j=k+1t∂Sj−1∂Sj这个连乘式,我们可以使这个连乘式中每项的偏导 ∂Sj−1∂Sj≈0 或 ∂Sj−1∂Sj≈1。这就是LSTM做的事情。下一篇文章介绍LSTM。
references
[1] 零基础入门深度学习(5) - 循环神经网络
[2] Recurrent Neural Networks Tutorial, Part 1 – Introduction to RNNs
[3] Recurrent Neural Network系列3–理解RNN的BPTT算法和梯度消失
[4] On the difficulty of training recurrent neural networks