RNN梯度消失和梯度爆炸的原因

  RNN的提出是为了解决网络无法利用历史信息的问题,但由于RNN具有梯度消失和梯度爆炸的问题,导致RNN不能存储长期记忆。

网络结构

  首先来看RNN的结构,如下图1所示:

RNN梯度消失和梯度爆炸的原因

  上图的结构很好理解,xtx_{t}为网络输入,AA为隐藏层,hth_{t}为网络输出。既然我们想利用之前的历史信息,那我们就将网络在上一时刻的输出保存下来,作为当前时刻的输入,也就是上图中的反馈连接。我们将上图中的RNN结构按时序展开,如下图2所示:

RNN梯度消失和梯度爆炸的原因

  x0x_{0}~xtx_{t}是网络在不同时刻的输入,h0h_{0} ~hth_{t}是网络在不同时刻的输入,A是隐藏层。需要注意的是,上图中的RNN展开图是RNN按时序的展开图,并不是真正的拓扑结构,对于某一固定的时刻tt,RNN的结构就是图1;这是很多资料容易让人产生误解的地方。所以,图2中的那么多A其实是同一个隐藏层,这也就是RNN中的“参数共享”。当然,你也可以增加RNN的深度,即增加隐藏层,如下图3所示:

RNN梯度消失和梯度爆炸的原因

如上图所示,纵向是增加网络深度,横向是增加时间步。

工作原理

  介绍了RNN的网络结构,下面来看RNN的工作过程。我们假设网络只有一个隐藏层,网络输入为xx,输出为yy,隐藏层状态为hh,如下图4所示,

RNN梯度消失和梯度爆炸的原因

则在时刻tt有:

ht=f(wix+whht1) h_{t}=f(w_{i}x+w_{h}h_{t-1})
yt=f(woht) y_{t}=f(w_{o}h_{t})

上式中,ff为**函数,一般为sigmoidsigmoidtanhtanh

梯度消失与梯度爆炸

  了解了RNN的工作原理,下面我们就可以去分析RNN梯度消失和梯度爆炸的原因了。为了简化问题,只考虑三个时间步,如下图5所示:

RNN梯度消失和梯度爆炸的原因

则有:

h1=f(wix1+whh0),y1=f(woh1) h_{1}=f(w_{i}x_{1}+w_{h}h_{0}) , y_{1}=f(w_{o}h_{1})

h2=f(wix2+whh1),y2=f(woh2) h_{2}=f(w_{i}x_{2}+w_{h}h_{1}) , y_{2}=f(w_{o}h_{2})

h3=f(wix3+whh2),y3=f(woh3) h_{3}=f(w_{i}x_{3}+w_{h}h_{2}) , y_{3}=f(w_{o}h_{3})

RNN的损失函数为

L=t=0TLt=t=0Tg(yt) L=\sum_{t=0}^{T}L_{t}=\sum_{t=0}^{T}g(y_{t})

LtL_{t}tt时刻输出的损失,gg为网络的损失函数。根据链式求导法则,求L对各个参数的偏导即为参数更新的梯度。

先只考虑L3L_{3}求偏导,有:

L3wo=L3y3y3wo \frac{\partial L_{3}}{\partial w_{o}}=\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial w_{o}}

L3wi=L3y3y3h3h3wi+L3y3y3h3h3h2h2wi+L3y3y3h3h3h2h2h1h1wi\frac{\partial L_{3}}{\partial w_{i}}=\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial h_{3}}\frac{\partial h_{3}}{\partial w_{i}}+\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial h_{3}}\frac{\partial h_{3}}{\partial h_{2}}\frac{\partial h_{2}}{\partial w_{i}}+\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial h_{3}}\frac{\partial h_{3}}{\partial h_{2}}\frac{\partial h_{2}}{\partial h_{1}}\frac{\partial h_{1}}{\partial w_{i}}

L3wh=L3y3y3h3h3wh+L3y3y3h3h3h2h2wh+L3y3y3h3h3h2h2h1h1wh\frac{\partial L_{3}}{\partial w_{h}}=\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial h_{3}}\frac{\partial h_{3}}{\partial w_{h}}+\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial h_{3}}\frac{\partial h_{3}}{\partial h_{2}}\frac{\partial h_{2}}{\partial w_{h}}+\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial h_{3}}\frac{\partial h_{3}}{\partial h_{2}}\frac{\partial h_{2}}{\partial h_{1}}\frac{\partial h_{1}}{\partial w_{h}}

观察上式,由于ht,t(0,T)h_{t},t\in (0,T)的存在,使得损失函数对参数求偏导的过程中存在大量的复合求导。再将上述等式推广到所有时间步,则有

Lwo=t=0TLtytytwo\frac{\partial L}{\partial w_{o}}=\sum_{t=0}^{T}\frac{\partial L_{t}}{\partial y_{t}}\frac{\partial y_{t}}{\partial w_{o}}

Lwi=t=0Tj=0tLtytytht(k=j+1thkhk1)hjwi\frac{\partial L}{\partial w_{i}}=\sum_{t=0}^{T}\sum_{j=0}^{t}\frac{\partial L_{t}}{\partial y_{t}}\frac{\partial y_{t}}{\partial h_{t}}(\prod_{k=j+1}^{t}\frac{\partial h_{k}}{\partial h_{k-1}})\frac{\partial h_{j}}{\partial w_{i}}

Lwh=t=0Tj=0tLtytytht(k=j+1thkhk1)hjwh\frac{\partial L}{\partial w_{h}}=\sum_{t=0}^{T}\sum_{j=0}^{t}\frac{\partial L_{t}}{\partial y_{t}}\frac{\partial y_{t}}{\partial h_{t}}(\prod_{k=j+1}^{t}\frac{\partial h_{k}}{\partial h_{k-1}})\frac{\partial h_{j}}{\partial w_{h}}

推导到这里,RNN梯度消失和梯度爆炸的原因就产生了。上述的第二个和第三个等式中出现了与时间tt相关的连乘的因式,根据第二节中RNN工作原理的介绍,以第二个等式同理,

hkhk1=fwi \frac{\partial h_{k}}{\partial h_{k-1}}=f^{'}\cdot w_{i}

其中ff^{'}为**函数的导数,以sigmoidsigmoid函数为例,f(0,1)f\in(0,1)其导数为f=f(1f)(0,14)f^{'}=f(1-f)\in(0,\frac{1}{4}),则wi<1w_{i}<1时,hkhk1<1\frac{\partial h_{k}}{\partial h_{k-1}}<1,经过数次相乘后,Lwi\frac{\partial L}{\partial w_{i}}逐渐接近于0,即梯度消失;wi>4w_{i}>4时,hkhk1>1\frac{\partial h_{k}}{\partial h_{k-1}}>1,经过数次相乘后,Lwi\frac{\partial L}{\partial w_{i}}越来越大,即梯度爆炸。

  至此,我们就从理论上分析了RNN中存在梯度消失和梯度爆炸的原因。但为了能够使用RNN利用历史信息的特性,对RNN的结构进行适当的改造就能得到性能更加优越的LSTM。LSTM的结构大大缓解了传统RNN中存在的梯队消失和梯度爆炸的问题,从而使时间步能够大大增长。具体的分析请参考下一篇文章。