【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)

在前面学习的循环网路中,因为梯度中有参数weight_hh的k次幂的存在,所以会导致梯度弥散和梯度爆炸的问题。对于梯度爆炸问题,可以用PyTorch笔记22最后面给出的梯度裁剪的方式解决。但是梯度弥散的问题没法这样直接解决,LSTM一定程度上解决了这样的问题,从而为长序列记忆提供了较好的解决方案。

长序列难题

在原始的循环网络中,实际上能处理的记忆信息比较短。如对自然语言的处理中,只能记住之前较少的几个单词的语境信息。例如"The clouds are in the sky"其中可以由"clouds"预测出"sky",它们之间的时刻比较接近。
【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)
但是循环网络很难从"I grew up in France… I speak fluent"预测出下一个词是"French",因为中间还有太多的单词,这就是长序列难题。LSTM就可以更好的处理长序列问题,其中STM三个字母就表示在循环网络中的记忆单元Short-Term Memory,它表示只能做短期的记忆,而LSTM的含义就是把记忆单元的短记忆延长了,所以前面加个单词Long。
【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)

回顾循环网络的结构

在前面学的循环网络中,只是单纯的将上次处理完的记忆单元和当前输入经线性变换后加在一起,再直接用Tanh反曲正切**:
ht=Tanh(Whhht1+Wthxt+b) h_t = Tanh(W_{hh} \cdot h_{t-1}+W_{th} \cdot x_t+b)

【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)
它还可以表示成之前的隐含单元hith_{i-t}和当前输入xtx_t经组合后由一个大的线性变换矩阵处理,再进行Tanh**:
ht=Tanh(W[ht1,xt]+b) h_t = Tanh(W \cdot [h_{t-1},x_t]+b)

这个表达方式对后面LSTM的前向计算描述很重要。

LSTM的门控思想

在数字电路中,门只有0和1两种状态。直观来看,LSTM的门控也是将信息有目的的过滤,为了取0倍到1倍之间的连续值,采用sigmoid值来和信息进行element-wise相乘。而这样的门控机制在旧的记忆信息、新输入的信息、输出信息时都要做。在图中可以看到圈出来的σ\sigma的地方就是门控的地方:
【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)
而门控的sigmoid值作为流量控制量——开度,显然也需要由网络自己学习得到,所以sigmoid的输入设计成网络此时的状态即xtx_tht1h_{t-1}的组合变换:
σ=sigmoid(W[ht1,xt]+b) \sigma = sigmoid(W \cdot [h_{t-1},x_t]+b)

注意,图中可以看到有两条水平的、沿着时间轴传递的通道,其中上面一条传递的是C,它才是LSTM中的"记忆",而下面传递的是是循环网络中也有的h,它是一种"隐含状态"的表示,同时也是LSTM的输出。

遗忘门(Forget gate)

因为σ\sigma越大乘下来之后信息保留的就越多,遗忘门实际上应该叫"记忆门"更符合语义一些。门的开度还是(三个门公式都是一样的,但参数互不影响):
ft=sigmoid(Wf[ht1,xt]+bf) f_t = sigmoid(W_f \cdot [h_{t-1},x_t]+b_f)

遗忘门控制的的是上一次层传进来的的记忆信息Ct1C_{t-1},如图所示:
【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)

输入门(Input gate)

输入门处理的是这一层输入的"Cell State",并不是单纯的处理输入xtx_t,而是处理像循环网络中的和隐藏单元聚合后做Tanh**后的状态信息:
Ct~=Tanh(WC[ht1,xt]+bC) \tilde{C_t} = Tanh(W_C \cdot [h_{t-1},x_t]+b_C)

输入门开度的计算还是:
it=sigmoid(Wi[ht1,xt]+bi) i_t = sigmoid(W_i \cdot [h_{t-1},x_t]+b_i)

控制过程如图:
【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)

新的记忆信息的计算

将当前的遗忘门开度ftf_t作用在上一层的记忆信息Ct1C_{t-1}上,将当前的输入门开度iti_t作用在当前状态信息Ct~\tilde{C_t}上,然后将它们相加,即得到当前Cell的记忆信息CtC_t
Ct=ftCt1+itCt~ C_t = f_t \cdot C_{t-1} + i_t \cdot \tilde{C_t}

如图所示:
【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)

输出门(Output gate)

之所以先说记忆CtC_t的计算,是因为输出门是建立在记忆计算完成的基础上的,具体地,是将记忆CtC_t进行Tanh**之后,再用输出门oto_t对其进行限制,得到本时刻的输出(即隐含状态)hth_t
ht=otTanh(Ct) h_t = o_t \cdot Tanh(C_t)

其中输出门开度的计算同样是:
ot=sigmoid(Wo[ht1,xt]+bo) o_t = sigmoid(W_o \cdot [h_{t-1},x_t]+b_o)

如图所示:
【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)
注意,图中hth_t除了向右输出还向上输出的原因是LSTM也可以有多层,当有多层时,本层的输入就是上一层对应时刻输出的hth_t。另外要注意,与之不同的是,记忆CtC_t只能沿着时间线横向传播。

总结

LSTM整个前向过程比较简洁的公式表示:
【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)


对LSTM门控机制的直观而极端的理解(说这是一种极端的理解,因为门是取连续值的,而不是像数字电路一样取离散值):
【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)
当输入门关闭,遗忘门全开时,即是完全取用上一时刻的记忆不变。

当输入门和遗忘门都全开时,即是将上一时刻的记忆加到这一时刻的状态上,完全综合两者信息。

当输入门和遗忘门都完全关闭时,即是不取用任何信息,相当于在这一时刻"失忆"+"关闭一切感官。

当输入门全开,遗忘门完全关闭时,即是完全依靠现有的信息和隐含状态,而丢掉之前的记忆信息。


最后是关于为什么LSTM能解决梯度弥散的问题。循环网络中会发生梯度弥散,是因为相邻时刻的梯度是这样的形式:
【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)
其中WRW_R就是之前学的WhhW_hh,所以层数多了之后,链式法则相乘会有一堆WhhW_{hh}乘在一起:
【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)
而循环网络完全没法保证这些WhhW_hh不会都小于1,也就没法避免梯度弥散了。

对于LSTM而言,相邻时刻的梯度是这样的:
【DL学习笔记】4:长短期记忆网络(Long Short-Term Memory)
可以看到是若干项相加的形式,这样加起来仍然小于1的概率就小了很多了,从而不容易发生梯度弥散。