本篇文章参考于 RNN梯度消失和爆炸的原因、Towser关于LSTM如何来避免梯度弥散和梯度爆炸?的问题解答、Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass。
看本篇文章之前,建议自行学习RNN和LSTM的前向和反向传播过程,学习教程可参考刘建平老师博客循环神经网络(RNN)模型与前向反向传播算法、LSTM模型与前向反向传播算法。
具体了解LSTM如何解决RNN所带来的梯度消失问题之前,我们需要明白为什么RNN会带来梯度消失问题。
1. RNN梯度消失原因
如上图所示,为RNN模型结构,前向传播过程包括,
-
隐藏状态:h(t)=σ(z(t))=σ(Ux(t)+Wh(t−1)+b),此处**函数一般为tanh。
-
模型输出:o(t)=Vh(t)+c
-
预测输出:y^(t)=σ(o(t)),此处**函数一般为softmax。
-
模型损失:L=∑t=1TL(t)
RNN反向传播过程中,需要计算U,V,W等参数的梯度,以W的梯度表达式为例,
∂W∂L=t=1∑T∂y(T)∂L∂o(T)∂y(T)∂h(T)∂o(T)∂h(t)∂h(T)∂W∂h(t)
现在需要重点计算∂h(t)∂h(T)部分,展开得到,
∂h(t)∂h(T)=∂h(T−1)∂h(T)∂h(T−2)∂h(T−1)...∂h(t)∂h(t+1)=k=t+1∏T∂h(k−1)∂h(k)=k=t+1∏Ttanh′(z(k))W
那么W的梯度表达式也就是,
∂W∂L=t=1∑T∂y(T)∂L∂o(T)∂y(T)∂h(T)∂o(T)(k=t+1∏T∂h(k−1)∂h(k))∂W∂h(t)=t=1∑T∂y(T)∂L∂o(T)∂y(T)∂h(T)∂o(T)(k=t+1∏Ttanh′(z(k))W)∂W∂h(t)
其中tanh′(z(k))=diag(1−(z(k))2)≤1,随着梯度的传导,如果W的主特征值小于1,梯度便会消失,如果W的特征值大于1,梯度便会爆炸。
需要注意的是,RNN和DNN梯度消失和梯度爆炸含义并不相同。RNN中权重在各时间步内共享,最终的梯度是各个时间步的梯度和。因此,RNN中总的梯度是不会消失的,即使梯度越传越弱,也只是远距离的梯度消失。 RNN所谓梯度消失的真正含义是,梯度被近距离梯度主导,远距离梯度很小,导致模型难以学到远距离的信息。 明白了RNN梯度消失的原因之后,我们看LSTM如何解决问题的呢?
2. LSTM为什么有效?
如上图所示,为RNN门控结构,前向传播过程包括,
-
遗忘门输出:f(t)=σ(Wfh(t−1)+Ufx(t)+bf)
-
输入门输出:i(t)=σ(Wih(t−1)+Uix(t)+bi), a(t)=tanh(Wah(t−1)+Uax(t)+ba)
-
细胞状态:C(t)=C(t−1)⊙f(t)+i(t)⊙a(t)
-
输出门输出:o(t)=σ(Woh(t−1)+Uox(t)+bo), h(t)=o(t)⊙tanh(C(t))
-
预测输出:y^(t)=σ(Vh(t)+c)
RNN梯度消失的原因是,随着梯度的传导,梯度被近距离梯度主导,模型难以学习到远距离的信息。具体原因也就是∏k=t+1T∂h(k−1)∂h(k)部分,在迭代过程中,每一步∂h(k−1)∂h(k)始终在[0,1]之间或者始终大于1。
而对于LSTM模型而言,针对∂C(k−1)∂C(k)求得,
∂C(k−1)∂C(k)=∂f(k)∂C(k)∂h(k−1)∂f(k)∂C(k−1)∂h(k−1)+∂i(k)∂C(k)∂h(k−1)∂i(k)∂C(k−1)∂h(k−1)+∂a(k)∂C(k)∂h(k−1)∂a(k)∂C(k−1)∂h(k−1)+∂C(k−1)∂C(k)
具体计算后得到,
∂C(k−1)∂C(k)=C(k−1)σ′(⋅)Wf∗o(k−1)tanh′(C(k−1))+a(k)σ′(⋅)Wi∗o(k−1)tanh′(C(k−1))+i(k)tanh′(⋅)Wc∗o(k−1)tanh′(C(k−1))+f(t)
k=t+1∏T∂C(k−1)∂C(k)=(f(k)f(k+1)...f(T))+other
在LSTM迭代过程中,针对∏k=t+1T∂C(k−1)∂C(k)而言,每一步∂C(k−1)∂C(k)可以自主的选择在[0,1]之间,或者大于1,因为f(k)是可训练学习的。那么整体∏k=t+1T∂C(k−1)∂C(k)也就不会一直减小,远距离梯度不至于完全消失,也就能够解决RNN中存在的梯度消失问题。LSTM虽然能够解决梯度消失问题,但并不能够避免梯度爆炸问题,仍有可能发生梯度爆炸。但是,由于LSTM众多门控结构,和普通RNN相比,LSTM发生梯度爆炸的频率要低很多。梯度爆炸可通过梯度裁剪解决。
LSTM遗忘门值可以选择在[0,1]之间,让LSTM来改善梯度消失的情况。也可以选择接近1,让遗忘门饱和,此时远距离信息梯度不消失。也可以选择接近0,此时模型是故意阻断梯度流,遗忘之前信息。更深刻理解可参考LSTM如何来避免梯度弥散和梯度爆炸?中回答。