RNN的提出是为了解决网络无法利用历史信息的问题,但由于RNN具有梯度消失和梯度爆炸的问题,导致RNN不能存储长期记忆。
网络结构
首先来看RNN的结构,如下图1所示:
上图的结构很好理解,xt为网络输入,A为隐藏层,ht为网络输出。既然我们想利用之前的历史信息,那我们就将网络在上一时刻的输出保存下来,作为当前时刻的输入,也就是上图中的反馈连接。我们将上图中的RNN结构按时序展开,如下图2所示:
x0~xt是网络在不同时刻的输入,h0 ~ht是网络在不同时刻的输入,A是隐藏层。需要注意的是,上图中的RNN展开图是RNN按时序的展开图,并不是真正的拓扑结构,对于某一固定的时刻t,RNN的结构就是图1;这是很多资料容易让人产生误解的地方。所以,图2中的那么多A其实是同一个隐藏层,这也就是RNN中的“参数共享”。当然,你也可以增加RNN的深度,即增加隐藏层,如下图3所示:
如上图所示,纵向是增加网络深度,横向是增加时间步。
工作原理
介绍了RNN的网络结构,下面来看RNN的工作过程。我们假设网络只有一个隐藏层,网络输入为x,输出为y,隐藏层状态为h,如下图4所示,
则在时刻t有:
ht=f(wix+whht−1)
yt=f(woht)
上式中,f为**函数,一般为sigmoid或tanh。
梯度消失与梯度爆炸
了解了RNN的工作原理,下面我们就可以去分析RNN梯度消失和梯度爆炸的原因了。为了简化问题,只考虑三个时间步,如下图5所示:
则有:
h1=f(wix1+whh0),y1=f(woh1)
h2=f(wix2+whh1),y2=f(woh2)
h3=f(wix3+whh2),y3=f(woh3)
RNN的损失函数为
L=t=0∑TLt=t=0∑Tg(yt),
Lt为t时刻输出的损失,g为网络的损失函数。根据链式求导法则,求L对各个参数的偏导即为参数更新的梯度。
先只考虑L3求偏导,有:
∂wo∂L3=∂y3∂L3∂wo∂y3
∂wi∂L3=∂y3∂L3∂h3∂y3∂wi∂h3+∂y3∂L3∂h3∂y3∂h2∂h3∂wi∂h2+∂y3∂L3∂h3∂y3∂h2∂h3∂h1∂h2∂wi∂h1
∂wh∂L3=∂y3∂L3∂h3∂y3∂wh∂h3+∂y3∂L3∂h3∂y3∂h2∂h3∂wh∂h2+∂y3∂L3∂h3∂y3∂h2∂h3∂h1∂h2∂wh∂h1
观察上式,由于ht,t∈(0,T)的存在,使得损失函数对参数求偏导的过程中存在大量的复合求导。再将上述等式推广到所有时间步,则有
∂wo∂L=t=0∑T∂yt∂Lt∂wo∂yt
∂wi∂L=t=0∑Tj=0∑t∂yt∂Lt∂ht∂yt(k=j+1∏t∂hk−1∂hk)∂wi∂hj
∂wh∂L=t=0∑Tj=0∑t∂yt∂Lt∂ht∂yt(k=j+1∏t∂hk−1∂hk)∂wh∂hj
推导到这里,RNN梯度消失和梯度爆炸的原因就产生了。上述的第二个和第三个等式中出现了与时间t相关的连乘的因式,根据第二节中RNN工作原理的介绍,以第二个等式同理,
∂hk−1∂hk=f′⋅wi
其中f′为**函数的导数,以sigmoid函数为例,f∈(0,1)其导数为f′=f(1−f)∈(0,41),则wi<1时,∂hk−1∂hk<1,经过数次相乘后,∂wi∂L逐渐接近于0,即梯度消失;wi>4时,∂hk−1∂hk>1,经过数次相乘后,∂wi∂L越来越大,即梯度爆炸。
至此,我们就从理论上分析了RNN中存在梯度消失和梯度爆炸的原因。但为了能够使用RNN利用历史信息的特性,对RNN的结构进行适当的改造就能得到性能更加优越的LSTM。LSTM的结构大大缓解了传统RNN中存在的梯队消失和梯度爆炸的问题,从而使时间步能够大大增长。具体的分析请参考下一篇文章。