LSTM如何解决RNN带来的梯度消失问题

本篇文章参考于 RNN梯度消失和爆炸的原因Towser关于LSTM如何来避免梯度弥散和梯度爆炸?的问题解答、Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass

看本篇文章之前,建议自行学习RNN和LSTM的前向和反向传播过程,学习教程可参考刘建平老师博客循环神经网络(RNN)模型与前向反向传播算法LSTM模型与前向反向传播算法

具体了解LSTM如何解决RNN所带来的梯度消失问题之前,我们需要明白为什么RNN会带来梯度消失问题。

1. RNN梯度消失原因

LSTM如何解决RNN带来的梯度消失问题
如上图所示,为RNN模型结构,前向传播过程包括,

  • 隐藏状态:h(t)=σ(z(t))=σ(Ux(t)+Wh(t1)+b)h^{(t)} = \sigma (z^{(t)}) = \sigma(Ux^{(t)} + Wh^{(t-1)} + b),此处**函数一般为tanhtanh
  • 模型输出:o(t)=Vh(t)+co^{(t)} = Vh^{(t)} + c
  • 预测输出:y^(t)=σ(o(t))\hat{y}^{(t)} = \sigma(o^{(t)}),此处**函数一般为softmax。
  • 模型损失:L=t=1TL(t)L = \sum_{t = 1}^{T} L^{(t)}

RNN反向传播过程中,需要计算U,V,WU, V, W等参数的梯度,以WW的梯度表达式为例,
LW=t=1TLy(T)y(T)o(T)o(T)h(T)h(T)h(t)h(t)W \frac{\partial L}{\partial W} = \sum_{t = 1}^{T} \frac{\partial L}{\partial y^{(T)}} \frac{\partial y^{(T)}}{\partial o^{(T)}} \frac{\partial o^{(T)}}{\partial h^{(T)}} \frac{\partial h^{(T)}}{\partial h^{(t)}} \frac{\partial h^{(t)}}{\partial W}

现在需要重点计算h(T)h(t)\frac{\partial h^{(T)}}{\partial h^{(t)}}部分,展开得到,
h(T)h(t)=h(T)h(T1)h(T1)h(T2)...h(t+1)h(t)=k=t+1Th(k)h(k1)=k=t+1Ttanh(z(k))W \frac{\partial h^{(T)}}{\partial h^{(t)}} = \frac{\partial h^{(T)}}{\partial h^{(T-1)}} \frac{\partial h^{(T - 1)}}{\partial h^{(T-2)}} ...\frac{\partial h^{(t+1)}}{\partial h^{(t)}} = \prod_{k=t + 1}^{T} \frac{\partial h^{(k)}}{\partial h^{(k - 1)}} = \prod_{k=t+1}^{T} tanh^{'}(z^{(k)}) W

那么WW的梯度表达式也就是,
LW=t=1TLy(T)y(T)o(T)o(T)h(T)(k=t+1Th(k)h(k1))h(t)W=t=1TLy(T)y(T)o(T)o(T)h(T)(k=t+1Ttanh(z(k))W)h(t)W \frac{\partial L}{\partial W} = \sum_{t = 1}^{T} \frac{\partial L}{\partial y^{(T)}} \frac{\partial y^{(T)}}{\partial o^{(T)}} \frac{\partial o^{(T)}}{\partial h^{(T)}} \left( \prod_{k=t + 1}^{T} \frac{\partial h^{(k)}}{\partial h^{(k - 1)}} \right) \frac{\partial h^{(t)}}{\partial W} \\ = \sum_{t = 1}^{T} \frac{\partial L}{\partial y^{(T)}} \frac{\partial y^{(T)}}{\partial o^{(T)}} \frac{\partial o^{(T)}}{\partial h^{(T)}} \left( \prod_{k=t+1}^{T} tanh^{'}(z^{(k)}) W \right) \frac{\partial h^{(t)}}{\partial W} \\

其中tanh(z(k))=diag(1(z(k))2)1tanh^{'}(z^{(k)}) = diag(1-(z^{(k)})^2) \leq 1,随着梯度的传导,如果WW的主特征值小于1,梯度便会消失,如果W的特征值大于1,梯度便会爆炸。

需要注意的是,RNN和DNN梯度消失和梯度爆炸含义并不相同。RNN中权重在各时间步内共享,最终的梯度是各个时间步的梯度和。因此,RNN中总的梯度是不会消失的,即使梯度越传越弱,也只是远距离的梯度消失。 RNN所谓梯度消失的真正含义是,梯度被近距离梯度主导,远距离梯度很小,导致模型难以学到远距离的信息。 明白了RNN梯度消失的原因之后,我们看LSTM如何解决问题的呢?

2. LSTM为什么有效?

LSTM如何解决RNN带来的梯度消失问题

如上图所示,为RNN门控结构,前向传播过程包括,

  • 遗忘门输出:f(t)=σ(Wfh(t1)+Ufx(t)+bf)f^{(t)} = \sigma(W_fh^{(t-1)} + U_fx^{(t)} + b_f)

  • 输入门输出:i(t)=σ(Wih(t1)+Uix(t)+bi)i^{(t)} = \sigma(W_ih^{(t-1)} + U_ix^{(t)} + b_i), a(t)=tanh(Wah(t1)+Uax(t)+ba)a^{(t)} = tanh(W_ah^{(t-1)} + U_ax^{(t)} + b_a)

  • 细胞状态:C(t)=C(t1)f(t)+i(t)a(t)C^{(t)} = C^{(t-1)}\odot f^{(t)} + i^{(t)}\odot a^{(t)}

  • 输出门输出:o(t)=σ(Woh(t1)+Uox(t)+bo)o^{(t)} = \sigma(W_oh^{(t-1)} + U_ox^{(t)} + b_o), h(t)=o(t)tanh(C(t))h^{(t)} = o^{(t)}\odot tanh(C^{(t)})

  • 预测输出:y^(t)=σ(Vh(t)+c)\hat{y}^{(t)} = \sigma(Vh^{(t)}+c)

RNN梯度消失的原因是,随着梯度的传导,梯度被近距离梯度主导,模型难以学习到远距离的信息。具体原因也就是k=t+1Th(k)h(k1)\prod_{k=t+1}^{T}\frac{\partial h^{(k)}}{\partial h^{(k - 1)}}部分,在迭代过程中,每一步h(k)h(k1)\frac{\partial h^{(k)}}{\partial h^{(k - 1)}}始终在[0,1]之间或者始终大于1。

而对于LSTM模型而言,针对C(k)C(k1)\frac{\partial C^{(k)}}{\partial C^{(k-1)}}求得,

C(k)C(k1)=C(k)f(k)f(k)h(k1)h(k1)C(k1)+C(k)i(k)i(k)h(k1)h(k1)C(k1)+C(k)a(k)a(k)h(k1)h(k1)C(k1)+C(k)C(k1) \frac{\partial C^{(k)}}{\partial C^{(k-1)}} = \frac{\partial C^{(k)}}{\partial f^{(k)}} \frac{\partial f^{(k)}}{\partial h^{(k-1)}} \frac{\partial h^{(k-1)}}{\partial C^{(k-1)}} + \frac{\partial C^{(k)}}{\partial i^{(k)}} \frac{\partial i^{(k)}}{\partial h^{(k-1)}} \frac{\partial h^{(k-1)}}{\partial C^{(k-1)}} \\+ \frac{\partial C^{(k)}}{\partial a^{(k)}} \frac{\partial a^{(k)}}{\partial h^{(k-1)}} \frac{\partial h^{(k-1)}}{\partial C^{(k-1)}} + \frac{\partial C^{(k)}}{\partial C^{(k-1)}}

具体计算后得到,
C(k)C(k1)=C(k1)σ()Wfo(k1)tanh(C(k1))+a(k)σ()Wio(k1)tanh(C(k1))+i(k)tanh()Wco(k1)tanh(C(k1))+f(t) \frac{\partial C^{(k)}}{\partial C^{(k-1)}} = C^{(k-1)}\sigma^{'}(\cdot)W_f*o^{(k-1)}tanh^{'}(C^{(k-1)}) \\ + a^{(k)}\sigma^{'}(\cdot)W_i*o^{(k-1)}tanh^{'}(C^{(k-1)}) \\ + i^{(k)}tanh^{'}(\cdot)W_c*o^{(k-1)}tanh^{'}(C^{(k-1)}) \\ + f^{(t)}

k=t+1TC(k)C(k1)=(f(k)f(k+1)...f(T))+other \prod _{k=t+1}^{T} \frac{\partial C^{(k)}}{\partial C^{(k-1)}} = (f^{(k)}f^{(k+1)}...f^{(T)}) + other

在LSTM迭代过程中,针对k=t+1TC(k)C(k1)\prod _{k=t+1}^{T} \frac{\partial C^{(k)}}{\partial C^{(k-1)}}而言,每一步C(k)C(k1)\frac{\partial C^{(k)}}{\partial C^{(k-1)}}可以自主的选择在[0,1]之间,或者大于1,因为f(k)f^{(k)}是可训练学习的。那么整体k=t+1TC(k)C(k1)\prod _{k=t+1}^{T} \frac{\partial C^{(k)}}{\partial C^{(k-1)}}也就不会一直减小,远距离梯度不至于完全消失,也就能够解决RNN中存在的梯度消失问题。LSTM虽然能够解决梯度消失问题,但并不能够避免梯度爆炸问题,仍有可能发生梯度爆炸。但是,由于LSTM众多门控结构,和普通RNN相比,LSTM发生梯度爆炸的频率要低很多。梯度爆炸可通过梯度裁剪解决。

LSTM遗忘门值可以选择在[0,1]之间,让LSTM来改善梯度消失的情况。也可以选择接近1,让遗忘门饱和,此时远距离信息梯度不消失。也可以选择接近0,此时模型是故意阻断梯度流,遗忘之前信息。更深刻理解可参考LSTM如何来避免梯度弥散和梯度爆炸?中回答。