小袁讲长短期记忆网络(LSTM)

一, 什么是长短期

LSTM全名“ Long Short-term Memory”,中文名翻译为长短期记忆网络。小袁我刚接触这个网络的时候,一度以为长短期记忆网络既可以建模序列问题中的长期时间依赖,又可以有效地捕捉到序列数据的短期时间依赖,因而被命名为长短期记忆网络。事实上这样理解对也不对,对在LSTM确实既有捕捉序列数据的长的时间依赖,又有捕捉短的时间依赖的特性上。不对在LSTM的特性并不像我们通俗理解的长短期。英文表达而言就是“Long Short-term Memory” 和 “Long Short term Memory”的差别吧。这篇博客我会重点讲下我对“长短期”的理解,如有不正确的地方还望各位不吝指教!

注:本博客部分图片公式来源于网络,侵删。转载请注明出处!

1.1 为何会有LSTM

据各路文献博客所言,LSTM的提出是为了解决循环神经网络(RNN)无法捕捉序列的长期时间依赖的不足,RNN的核心状态更新公式为
ht=f(Wixt+Whht1) h_t=f(W^ix_t + W^hh_{t-1})
其中, hth_t 为RNN网络的隐藏层在时刻 tt 的状态值,f()f() 为RNN网络的**函数,通常为tanhtanh函数。

RNN的一种网络拓扑结构如下图所示:
小袁讲长短期记忆网络(LSTM)
由于第tt 时刻block内(上图中的绿框)的输入仅为上一时刻t1t-1 的状态值 ht1h_{t-1} 和当前时刻的输入 xtx_t ,因而RNN无法捕捉到序列数据的长期依赖,仅能捕捉到序列数据的短期依赖,这导致了RNN网络在建模上的天然不足。

事实上,对RNN的这种理解是不对的。 这种有失偏颇的理解会进一步给自己理解LSTM带来困难。上述理解主要问题在于第tt 时刻block内(上图中的绿框)的输入之一ht1h_{t-1}不是一个独立的变量,它的值通过 ht2h_{t-2}xt1x_{t-1} 计算得到(即 ht1h_{t-1} 包含 ht2h_{t-2} 的特征信息 )。递归地,hth_t 包含t=1,2,...,t1t=1,2,...,t-1 的所有隐藏层的状态特征,因而RNN事实上是有建模长期时间依赖的能力的。既然如此,那为何会有RNN无法捕捉长期的序列时间依赖关系的说法呢?所谓无风不起浪啊。事实上,这个可以用“理想很丰满,现实很骨干”来比喻。尽管RNN能够完美的建模序列数据的长期依赖关系,但是它没法用啊,因为传统的RNN非常容易陷入梯度消失或梯度爆炸问题,这导致了RNN网络在实际使用中,无法捕捉到序列的长期依赖关系。事实上相应的长短期记忆网络LSTM也是因为它在实际应用中能够巧妙地避免梯度消失或梯度爆炸问题,使得它能够捕捉到长期的序列时间依赖关系。简言之,LSTM的提出是为了克服在实际应用中 ,RNN建模的长期时间依赖关系无法通过梯度优化的不足。

1.2 谈谈RNN的梯度消失和梯度爆炸

关于RNN的梯度消失和梯度爆炸问题,参考了知乎文章 ,并结合评论和我的理解做了部分修正。具体细节如下所示:

定义参数优化的损失函数
L=t=0TLt L=\sum_{t=0}^{T}L_{t}
则损失函数LL 对参数矩阵WW 的偏导数为
LW=t=0TLtW \frac{\partial L}{\partial W} =\sum_{t=0}^{T}\frac{\partial L_t}{\partial W}
现考虑tt 时刻的损失函数误差对输出矩阵WoW^o ,隐藏层矩阵WhW^h,输入矩阵WiW^i 的偏导数,它们依次为
LtWo=LtytytWo \frac{\partial L_t}{\partial W^o} =\frac{\partial L_t}{\partial y_t}\frac{\partial y_t}{\partial W^o}

LtWh=k=0tLtytytht(i=k+1thihi1)hkWh \frac{\partial L_t}{\partial W^h} =\sum_{k=0}^{t}\frac{\partial L_t}{\partial y_t}\frac{\partial y_t}{\partial h_t}(\prod_{i=k+1}^{t}\frac{\partial h_i}{\partial h_{i-1}})\frac{\partial h_{k}}{\partial W^h}

LtWi=k=0tLtytytht(i=k+1thihi1)hkWi \frac{\partial L_t}{\partial W^i} =\sum_{k=0}^{t}\frac{\partial L_t}{\partial y_t}\frac{\partial y_t}{\partial h_t}(\prod_{i=k+1}^{t}\frac{\partial h_i}{\partial h_{i-1}})\frac{\partial h_{k}}{\partial W^i}

可以看到,在修正某个时刻tt 的误差时,离时刻tt越久远的时刻kk需要考虑到的隐藏层之间的偏导数hihi1\frac{\partial h_i}{\partial h_{i-1}} 的连乘次数越多。为了方便理解,我们假设xtx_thth_t 均为一维变量,则WhW_hWiW_i均为一维变量。 现在我们考虑下i=k+1thihi1\prod_{i=k+1}^{t}\frac{\partial h_i}{\partial h_{i-1}} ,由RNN的定义公式ht=f(Wixt+Whht1)h_t=f(W^ix_t + W^hh_{t-1})
htht1=Whf \frac{\partial h_t}{\partial h_{t-1}}=W_hf'
因为ff 为sigmoid函数,所以它的导数的上届为0.25,所以有
htht1=Whf0.25Wh \frac{\partial h_t}{\partial h_{t-1}}=W_hf'\leq0.25W_h
如果Wh4W_h\leq4,则恒有htht1\frac{\partial h_t}{\partial h_{t-1}}小于1。此时,若时刻tt与时刻kk的时差较大,则i=k+1thihi1\prod_{i=k+1}^{t}\frac{\partial h_i}{\partial h_{i-1}}趋于0。此时发生梯度消失现象。

或者某种情况下Whf1W_hf'\geq1,即htht1\frac{\partial h_t}{\partial h_{t-1}}大于1。此时,若时刻tt与时刻kk的时差较大,则i=k+1thihi1\prod_{i=k+1}^{t}\frac{\partial h_i}{\partial h_{i-1}}趋于无穷大。此时发生梯度爆炸现象。

既然如此,那我可否在初始化的时候选择一个好的WhW_h,使得刚好不会发生梯度消失和爆炸呢?事实上,一个好的初始化确实可以避免迭代算法在开始时更可能避免梯度消失和爆炸,然而随着迭代次数的增加,更新后的WhW_h就无法保证了,更多详细资料参考博客

1.3 LSTM的基本组成

如下给出了两种LSTM的框架单元表示图,在此我们不去细究图中每个变量的含义,有兴趣的伙伴参考李宏毅老师的教学视频。针对LSTM网络我们选择从公式出发去介绍LSTM。
小袁讲长短期记忆网络(LSTM)
首先我们给出几个基本的符号定义:

σσ表示sigmoid函数

hth_t表示tt时刻隐藏层的状态值

CtC_t表示tt时刻细胞层的状态值

ot,ft,ito_t, f_t, i_t依次表示tt时刻输出门,遗忘门和输入门的状态值

xt,ytx_t, y_t表示tt时刻网络的输入和输出

WRW_R表示不同的网络权重(不同的权重用不同下标表示)

对于tt时刻的网络的block(上图中的单个绿框),其信号输入为xtx_tt1t-1时刻的block的隐藏层输出ht1h_{t-1}和细胞层输出Ct1C_{t-1};信号输出为ht,Cth_t, C_tCtC_t的计算公式如下所示
Ct=ftCt1+itCt~ C_t=f_tC_{t-1}+i_t\tilde{C_t}

Ct~=tanh(Wc[ht1,xt]) \tilde{C_t}=tanh(W_c[h_{t-1},x_t])

hth_t的计算公式如下所示
ht=ottanh(Ct) h_t=o_ttanh(C_t)
其中,tt时刻输出门,遗忘门和输入门的状态值的计算公式如下所示
ot=σ(Wo[ht1,xt]) o_t = σ(W_{o}[h_{t-1},x_t])

ft=σ(Wf[ht1,xt]) f_t = σ(W_{f}[h_{t-1},x_t])

it=σ(Wi[ht1,xt]) i_t = σ(W_{i}[h_{t-1},x_t])

需要强调的是,现有的一些有关LSTM的框架流程图只能宏观的表示网络的输入输出和大致的流程,于小袁我而言这些流程图对于LSTM的刻画程度并没有公式来的直接和具体,因而小袁还是建议感兴趣的伙伴可以多多钻研钻研公式。

二,对LSTM的两脸懵逼

2.1 懵逼一:这个结构得多大脑洞想的

在上面的讲解中我们已经知道,RNN在梯度更新权重的过程中存在梯度消失问题,那LSTM网络就和小葵花妈妈一样,自然是哪里有问题改哪里。1997年Sepp Hochreiter在提出长短期记忆网络LSTM时,网络中的遗忘门的值ft=1f_t = 1,在这篇论文中,作者指出设计输入门和输出门的原因主要是为了解决冲突,原文如下:

  1. Input weight conflict: for simplicity, let us focus on a single additional input weight wjiw_{ji}. Assume that the total error can be reduced by switching on unit jj in response to a certain input, and keeping it active for a long time (until it helps to compute a desired output). Provided ii is non- zero, since the same incoming weight has to be used for both storing certain inputs and ignoring others, wjiw_{ji} will often receive conflicting weight update signals during this time (recall that jj is linear): these signals will attempt to make wjiw_{ji} participate in (1) storing the input (by switching on jj) and (2) protecting the input (by preventing jj from being switched off by irrelevant later inputs). This conflict makes learning difficult, and calls for a more context-sensitive mechanism for controlling “write operations” through input weights.
  2. Output weight conflict: assume jj is switched on and currently stores some previous input. For simplicity, let us focus on a single additional outgoing weight wkjw_{kj} . The same wkjw_{kj} has to be used for both retrieving jj 's content at certain times and preventing jj from disturbing kk at other times. As long as unit jj is non-zero, wkjw_{kj} will attract conflicting weight update signals generated during sequence processing: these signals will attempt to make wkjw_{kj} participate in (1) accessing the information stored in jj and — at different times — (2) protecting unit kk from being perturbed by jj . For instance, with many tasks there are certain “short time lag errors” that can be reduced in early training stages. However, at later training stages jj may suddenly start to cause avoidable errors in situations that already seemed under control by attempting to participate in reducing more difficult “long time lag errors”. Again, this conflict makes learning difficult, and calls for a more context-sensitive mechanism for controlling “read operations” through output weights.

就小袁个人理解而言,输入门和输出门是一种对信息的筛选机制,比如阻止t1t-1t2t-2时刻的网络的输入xt1x_{t-1}xt2x_{t-2}tt时刻的细胞状态值CtC_t的直接影响,则我只需要简单将输入门it1i_{t-1}it2i_{t-2}置零。在此,我们不妨假设在任何时刻,如果LSTM的ot,ito_t,i_t恒为1,ftf_t恒为1,此时的网络称为退化的LSTM。有:

对于tt时刻的网络的block(上图中的单个绿框),其信号输入为xtx_tt1t-1时刻的block的隐藏层输出ht1h_{t-1}和细胞层输出Ct1C_{t-1};信号输出为ht,Cth_t, C_tCtC_t的计算公式如下所示
Ct=Ct1+Ct~ C_t=C_{t-1}+\tilde{C_t}

Ct~=tanh(Wc[ht1,xt]) \tilde{C_t}=tanh(W_c[h_{t-1},x_t])

hth_t的计算公式如下所示
ht=tanh(Ct) h_t=tanh(C_t)
联合上述三个公式易知
ht=tanh(Ct1+tanh(Wc[ht1,xt])) h_t=tanh(C_{t-1}+tanh(W_c[h_{t-1},x_t]))
对比此时退化的LSTM神经网络结构和RNN更新方程
ht=tanh(Wixt+Whht1) h_t=tanh(W^ix_t + W^hh_{t-1})
可以看到相比RNN的结构而言,退化的LSTM引进了细胞的状态层,相比RNN,网络的单个block深了一层。

2.2 懵逼二: 咋就避免序列数据的梯度消失

如1.2所言,递归导数是导致梯度消失的主要因素,因此我们分析下在LSTM中,递归导数的数学表达。首先,LSTM的细胞的状态值更新公式为:
Ct=ftCt1+itCt~ C_t=f_tC_{t-1}+i_t\tilde{C_t}
又在上述公式中,ft,it,Ct~f_t,i_t,\tilde{C_t}是关于隐藏状态值ht1h_{t-1}的函数,ht1h_{t-1}是关于Ct1C_{t-1}的函数,因此根据链式求导法则,有
CtCt1=Ctftftht1ht1Ct1+CtCt1+Ctititht1ht1Ct1+CtC~tC~tht1ht1Ct1 \frac{\partial C_t}{\partial C_{t-1}}=\frac{\partial C_t}{\partial f_{t}}\frac{\partial f_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial C_{t-1}}+\frac{\partial C_t}{\partial i_{t}}\frac{\partial i_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_t}{\partial \tilde C_{t}}\frac{\partial\tilde C_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}}
化简有
CtCt1=Ct1σWfot1tanh(Ct1)+ft+C~tσWiot1tanh(Ct1)+ittanhWcot1tanh(Ct1) \frac{\partial C_t}{\partial C_{t-1}}=C_{t-1} σ'W_f*o_{t-1}tanh'(C_{t-1})+f_t+\tilde C_tσ'W_i*o_{t-1}tanh'(C_{t-1})+i_{t}tanh'W_c*o_{t-1}tanh'(C_{t-1})
我们现在将RNN的递归导数列出来,如下所示
htht1=Whf \frac{\partial h_t}{\partial h_{t-1}}=W_hf'
容易看到,LSTM的递归导数的值的大小与时间tt有关,即不同时刻的值可以大于1,或者在0~1区间。然而在RNN中,一旦Wh<4W_h<4(假设WhW_h为维度为1),则所有时刻的递归导数的值均小于1.这就使得RNN相比LSTM更易发生梯度消失问题。用weberna的博客的话说:“ In vanilla RNNs, the termshtht1\frac{\partial h_t}{\partial h_{t-1}} will eventually take on a values that are either always above 1 or always in the range [0,1], this is essentially what leads to the vanishing/exploding gradient problem. The terms here, CtCt1\frac{\partial C_t}{\partial C_{t-1}} ,at any time step can take on either values that are greater than 1 or values in the range [0,1]. Thus if we extend to an infinite amount of time steps, it is not guarented that we will end up converging to 0 or infinity (unlike in vanilla RNNs).”

也就是说,LSTM并不能保证能完全避免梯度消失,只是相比与RNN,递归导数中的ft,ot,itf_t,o_t,i_t的值是由数据驱动的,可调整的,因而更容易避免梯度迭代优化算法中的梯度消失问题。

三,博主碎碎念

在LSTM的学习理解过程中,博主觉得比较好的三个学习链接,推荐给大家:

如果你对LSTM的设计思路感兴趣,后者你有其它的understanding或idea,欢迎来私戳博主交流。

博主目前是枚科研秃头怪,如果你也撰写博客或者看到过一些经典算法的好博客,能推荐给我的话我将非常感谢!