这是一张经典的LSTM示意图,LSTM依靠 ft、it、ot来控制输入输出,ft=σ(Wf⋅[ht−1,xt]+bf)it=σ(Wi⋅[ht−1,xt]+bi)ot=σ(Wo[ht−1,xt]+bo)
我们将其简化为:ft=σ(WfXt+bf)it=σ(WiXt+bi)oi=σ(WoXt+bo)
当前的状态 St=ftSt−1+itXt 类似与传统RNN St=WsSt−1+WxXt+b1 。将LSTM的状态表达式展开后得:St=σ(WfXt+bf)St−1+σ(WiXt+bi)Xt 如果加上**函数St=tanh[σ(WfXt+bf)St−1+σ(WiXt+bi)Xt] RNN梯度消失和爆炸的原因这篇文章中传统RNN求偏导的过程包含:j=k+1∏t∂Sj−1∂Sj=j=k+1∏ttanh′Ws 对于LSTM同样也包含这样的一项,但是在LSTM中:j=k+1∏t∂Sj−1∂Sj=j=k+1∏ttanh′σ(WfXt+bf) 假设 Z=tanh′(x)σ(y),则Z的函数图像如下图所示:
可以看到该函数值基本上不是0就是1。
传统RNN的求偏导过程:∂Ws∂L3=∂O3∂L3∂S3∂O3∂Ws∂S3+∂O3∂L3∂S3∂O3∂S2∂S3∂Ws∂S2+∂O3∂L3∂S3∂O3∂S2∂S3∂S1∂S2∂Ws∂S1
在LSTM中为:∂Ws∂L3=∂O3∂L3∂S3∂O3∂Ws∂S3+∂O3∂L3∂S3∂O3∂Ws∂S2+∂O3∂L3∂S3∂O3∂Ws∂S1
因为j=k+1∏t∂Sj−1∂Sj=j=k+1∏ttanh′σ(WfXt+bf)≈0∣1
这样就解决了传统RNN中梯度消失的问题。