How LSTM Solves Vanishing and Exploding Gradients

[Figure: a classic LSTM cell schematic]
This is a classic LSTM schematic. LSTM relies on the gates $f_t$, $i_t$, and $o_t$ to control its input and output:

$$f_{t}=\sigma\left(W_{f} \cdot\left[h_{t-1}, x_{t}\right]+b_{f}\right)$$

$$i_{t}=\sigma\left(W_{i} \cdot\left[h_{t-1}, x_{t}\right]+b_{i}\right)$$

$$o_{t}=\sigma\left(W_{o} \cdot\left[h_{t-1}, x_{t}\right]+b_{o}\right)$$
We simplify these to:

$$f_{t}=\sigma\left(W_{f} X_{t}+b_{f}\right), \quad i_{t}=\sigma\left(W_{i} X_{t}+b_{i}\right), \quad o_{t}=\sigma\left(W_{o} X_{t}+b_{o}\right)$$
The current state is $S_{t}=f_{t} S_{t-1}+i_{t} X_{t}$, analogous to the traditional RNN state $S_{t}=W_{s} S_{t-1}+W_{x} X_{t}+b_{1}$. Expanding the LSTM state expression gives:

$$S_{t}=\sigma\left(W_{f} X_{t}+b_{f}\right) S_{t-1}+\sigma\left(W_{i} X_{t}+b_{i}\right) X_{t}$$

Adding the activation function yields:

$$S_{t}=\tanh \left[\sigma\left(W_{f} X_{t}+b_{f}\right) S_{t-1}+\sigma\left(W_{i} X_{t}+b_{i}\right) X_{t}\right]$$

As shown in the article "RNN梯度消失和爆炸的原因" (the causes of gradient vanishing and explosion in RNNs), backpropagation through a traditional RNN contains the product:

$$\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}=\prod_{j=k+1}^{t} \tanh^{\prime}\, W_{s}$$

LSTM contains an analogous product, but in LSTM it becomes:

$$\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}=\prod_{j=k+1}^{t} \tanh^{\prime}\, \sigma\left(W_{f} X_{t}+b_{f}\right)$$

Let $Z=\tanh^{\prime}(x)\, \sigma(y)$; the graph of $Z$ is shown below:
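As a concrete sketch of the simplified recurrence $S_{t}=f_{t} S_{t-1}+i_{t} X_{t}$ above (scalar weights, with values chosen purely for illustration, not taken from the text):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical scalar parameters, chosen only for illustration.
W_f, b_f = 1.0, 2.0   # forget-gate weight and bias
W_i, b_i = 1.0, 0.0   # input-gate weight and bias

def lstm_state_step(S_prev, X_t):
    """One step of the simplified recurrence S_t = f_t * S_{t-1} + i_t * X_t."""
    f_t = sigmoid(W_f * X_t + b_f)   # forget gate, in (0, 1)
    i_t = sigmoid(W_i * X_t + b_i)   # input gate, in (0, 1)
    return f_t * S_prev + i_t * X_t

# Run the recurrence over a short hypothetical input sequence.
S = 0.0
for X_t in [0.5, -0.2, 1.0]:
    S = lstm_state_step(S, X_t)
```

Because both gates are sigmoid outputs, each step scales the previous state by a factor in $(0, 1)$ rather than by a fixed weight matrix, which is the structural difference the derivation above exploits.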

[Figure: surface plot of $Z=\tanh^{\prime}(x)\,\sigma(y)$]
It can be seen that the value of this function is essentially either 0 or 1.
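A quick numerical check of $Z=\tanh^{\prime}(x)\,\sigma(y)$, using $\tanh^{\prime}(x)=1-\tanh^{2}(x)$, confirms the saturating behavior the figure shows: away from $x \approx 0$, or for strongly negative $y$, the value collapses toward 0, while near $x=0$ with large positive $y$ it approaches 1. A minimal sketch (the sample points are illustrative):

```python
import numpy as np

def Z(x, y):
    """Z(x, y) = tanh'(x) * sigmoid(y), with tanh'(x) = 1 - tanh(x)**2."""
    return (1.0 - np.tanh(x) ** 2) * (1.0 / (1.0 + np.exp(-y)))

print(Z(0.0, 50.0))   # near 1: tanh'(0) = 1 and the sigmoid saturates at 1
print(Z(0.0, -50.0))  # near 0: the sigmoid saturates at 0
print(Z(5.0, 50.0))   # near 0: tanh' decays quickly away from x = 0
```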
The backpropagation chain for a traditional RNN is:

$$\frac{\partial L_{3}}{\partial W_{s}}=\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial S_{1}} \frac{\partial S_{1}}{\partial W_{s}}$$
In LSTM it becomes:

$$\frac{\partial L_{3}}{\partial W_{s}}=\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{2}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{1}}{\partial W_{s}}$$
This is because:

$$\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}=\prod_{j=k+1}^{t} \tanh^{\prime}\, \sigma\left(W_{f} X_{t}+b_{f}\right) \approx 0 \text{ or } 1$$
When the factor is close to 1, the gradient passes through that step essentially unattenuated, which is how LSTM resolves the vanishing-gradient problem of the traditional RNN.
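The contrast can be seen numerically. In the sketch below (the step count, the pre-activations, and the values of $W_s$ and $b_f$ are all illustrative assumptions, not values from the text), the RNN-style product of $\tanh^{\prime} W_{s}$ factors decays geometrically over 50 steps, while the LSTM-style product of forget-gate activations stays on the order of 1 when the gates are driven close to 1:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

T = 50                               # number of unrolled time steps
a = np.linspace(-1.0, 1.0, T)        # hypothetical pre-activations per step

# Traditional RNN: each factor in the product is tanh'(a_t) * W_s.
# With tanh' <= 1 and W_s < 1, the product shrinks geometrically.
W_s = 0.9
rnn_product = np.prod((1.0 - np.tanh(a) ** 2) * W_s)

# LSTM: each factor is a forget-gate activation sigma(W_f X_t + b_f).
# Here b_f = 5 is an illustrative bias that keeps every gate near 1.
b_f = 5.0
lstm_product = np.prod(sigmoid(a + b_f))

print(rnn_product)    # decays toward 0 as T grows
print(lstm_product)   # remains on the order of 1
```

The design point is that the LSTM's per-step factor is a learned gate rather than a fixed shared weight, so the network itself can decide to hold the factor near 1 along paths where the gradient must be preserved.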