LSTM原理详解


回顾RNN详解中,RNN的缺陷是无法做到长期依赖。为此我们引入LSTM(Long Short Term Memory networks(以下简称LSTM)),是一种特殊的RNN,主要是为了解决长期依赖问题。同时,介绍一种LSTM的变体GRU,简化了LSTM,提高运算速度。

LSTM引例

先来看这样一个例子:

LSTM原理详解

我们希望RNN可以学习到喝咖啡和打王者荣耀之间的依赖关系,早上喝了咖啡,下午才有精力打王者荣耀,但是二者在时间上并不接近。如何把这个长时间依赖关系表达出来?原始RNN中理论上是可以将其表达出来,但由于上述所提到RNN的缺陷,原始RNN很难把这个依赖学习到。同样后边还有一个长时间依赖:中午如果打过王者荣耀,那么吃完晚饭就不打了,直接睡觉。这种长时间的依赖在序列数据中是很常见的,而LSTM可以很容易的学习到这种依赖。

LSTM原理

标准的RNN如下所示:

LSTM原理详解

在标准RNN的结构上加了点东西,整体结构如下所示:

LSTM原理详解

LSTM由三个门来控制细胞状态,这三个门分别称为忘记门、输入门和输出门。输入门控制当前计算的新状态以及以多大程度更新到记忆单元中;遗忘门控制前一步记忆单元中的信息以多大程度被遗忘掉;输出门控制当前的输出有多大程度取决于当前的记忆单元。接下来依次介绍:

遗忘门

主要决定决定细胞状态需要丢弃哪些信息。通过查看h(t1)h_{(t-1)}xtx_t的信息来输出一个0-1之间的向量,该向量中的数值表示状态Ct1C_{t-1} 中有多少信息保留或丢弃,0表示不保留,1表示都保留,遗忘门如下图所示:

LSTM原理详解
ft=σ(Wf[ht1,xt]+bf) \mathbf{f}_t=\sigma(W_f\cdot[\mathbf{h}_{t-1},\mathbf{x}_t]+\mathbf{b}_f)
式中:WfW_f是遗忘门的权重矩阵,[ht1,xt][\mathbf{h}_{t-1},\mathbf{x}_t]表示把两个向量连接成一个更长的向量,bfb_f是遗忘门的偏置项,σ\sigma 是sigmoid函数。

其中Wf[ht1,xt]W_f\cdot[\mathbf{h}_{t-1},\mathbf{x}_t] 可以理解为:
[Wf][ht1xt]=[WfhWfx][ht1xt]=Wfhht1+Wfxxt \begin{aligned} \begin{bmatrix}W_f\end{bmatrix}\begin{bmatrix}\mathbf{h}_{t-1}\\ \mathbf{x}_t\end{bmatrix}&= \begin{bmatrix}W_{fh}&W_{fx}\end{bmatrix}\begin{bmatrix}\mathbf{h}_{t-1}\\ \mathbf{x}_t\end{bmatrix}\\ &=W_{fh}\mathbf{h}_{t-1}+W_{fx}\mathbf{x}_t \end{aligned}

输入门

主要决定给细胞状态添加哪些新的信息。输入门如下图所示:

LSTM原理详解

(1)利用ht1\boldsymbol h_{t-1}xt\boldsymbol x_t通过一个称为输入门的操作来决定更新哪些信息。
it=σ(Wi[ht1,xt]+bi) \mathbf{i}_t=\sigma(W_i\cdot[\mathbf{h}_{t-1},\mathbf{x}_t]+\mathbf{b}_i)
(2)利过一个tanh层得到新的候选细胞信息c~t\mathbf{\tilde{c}}_t,这些信息可能会被更新到细胞信息中。
c~t=tanh(Wc[ht1,xt]+bc) \mathbf{\tilde{c}}_t=\tanh(W_c\cdot[\mathbf{h}_{t-1},\mathbf{x}_t]+\mathbf{b}_c)
(3)计算当前时刻的单元状态ct\boldsymbol c_{t} ,更新的规则就是通过忘记门选择忘记旧细胞信息的一部分,通过输入门选择添加候选细胞信息的一部分得到新的细胞信息。如下图所示:

LSTM原理详解

即:
ct=ftct1+itc~t \mathbf{c}_t=f_t*{\mathbf{c}_{t-1}}+i_t*{\mathbf{\tilde{c}}_t}
我们就把LSTM关于当前的记忆c~t\mathbf{\tilde{c}}_t和长期的记忆ct1\mathbf{c}_{t-1}组合在一起,形成了新的单元状态。

由于遗忘门的控制,它可以保存很久很久之前的信息,由于输入门的控制,它又可以避免当前无关紧要的内容进入记忆。接着,来看看输出门,

输出门

将输入经过一个称为输出门的sigmoid层得到判断条件,然后将细胞状态经过tanh层得到一个-1~1之间值的向量,该向量与输出门得到的判断条件相乘就得到了最终该RNN单元的输出。输出门如下图所示:

LSTM原理详解

输出门控制了长期记忆对当前输出的影响
ot=σ(Wo[ht1,xt]+bo) \mathbf{o}_t=\sigma(W_o\cdot[\mathbf{h}_{t-1},\mathbf{x}_t]+\mathbf{b}_o)
LSTM最终的输出,是由输出门和单元状态共同确定的:
ht=ottanh(ct) \mathbf{h}_t=\mathbf{o}_t* \tanh(\mathbf{c}_t)

LSTM总结

如何实现长期依赖?

在一个训练好的网络中,当输入序列没有重要信息时,LSTM遗忘门的值接近为1,输入门接近0,此时过去的记忆会被保存,从而实现了长期记忆;当输入的序列中出现了重要信息时,LSTM会将其存入记忆中,此时输入门的值会接近于1;当输入序列出现重要信息,且该信息意味着之前的记忆不再重要的时候,输入门接近1,遗忘门接近0,这样旧的记忆被遗忘,新的重要信息被记忆。经过这样的设计,整个网络更容易学习到序列之间的长期依赖。

如何避免梯度消失/爆炸?

在lstm中,状态c\mathbf c是通过累加的方式来计算的。不像RNN中的累乘的形式,这样的话,它的的导数也不是乘积的形式,这样就不会发生梯度消失的情况了。

GRU

GRU(Gated Recurrent Unit)作为LSTM的一种变体,与LSTM有两个不同点:

(1)GRU将LSTM中的两个信息流简化成一个信息流,输入只有一个ht\boldsymbol h_t

(2)GRU将忘记门和输入门合成了一个单一的更新门,还引入了一个重置门。

如下图所示:

LSTM原理详解
主要运算过程如下:
rt=σ(Wr[ht1,xt])zt=σ(Wz[ht1,xt])h~t=tanh(W[rtht1,xt])ht=(1zt)ht1+zth~t \begin{aligned} &r_t = \sigma(W_r\cdot[h_{t-1},x_t]) \\ &z_t = \sigma(W_z\cdot[h_{t-1},x_t]) \\ &\tilde h_t = \tanh(W \cdot[r_t * h_{t-1},x_t]) \\ & h_t = (1-z_t)*h_{t-1} + z_t*\tilde h_t \end{aligned}

相当于简化了LSTM,运算速度提高了很多,应用效果也没有差很多。

参考文章:

理解LSTM(通俗易懂版)

NLP面试题目汇总1-5