【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

点击上方“机器学习与生成对抗网络”,关注"星标"

获取有趣、好玩的前沿干货!

作者 陈楠  知乎

https://zhuanlan.zhihu.com/p/83496936

编辑 机器学习算法与自然语言处理

著作权归作者,文仅分享,侵删

1.长短期记忆网络LSTM

LSTM(Long short-term memory)通过刻意的设计来避免长期依赖问题,是一种特殊的RNN。长时间记住信息实际上是 LSTM 的默认行为,而不是需要努力学习的东西!

所有递归神经网络都具有神经网络的链式重复模块。在标准的RNN中,这个重复模块具有非常简单的结构,例如只有单个tanh层,如下图所示。

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

图1 RNN结构图

LSTM具有同样的结构,但是重复的模块拥有不同的结构,如下图所示。与RNN的单一神经网络层不同,这里有四个网络层,并且以一种非常特殊的方式进行交互。

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

图2 LSTM结构图


【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

  1.1 LSTM--遗忘门

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

图3 遗忘门

LSTM 的第一步要决定从细胞状态中舍弃哪些信息。这一决定由所谓“遗忘门层”的 S 形网络层做出。它接收 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 和 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 ,并且对细胞状态 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 中的每一个数来说输出值都介于 0 和 1 之间。1 表示“完全接受这个”,0 表示“完全忽略这个”。

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

  1.2 LSTM--输入门

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

图4 输入门

下一步就是要确定需要在细胞状态中保存哪些新信息。这里分成两部分。第一部分,一个所谓“输入门层”的 S 形网络层确定哪些信息需要更新。第二部分,一个 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 形网络层创建一个新的备选值向量—— 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 ,可以用来添加到细胞状态。在下一步中我们将上面的两部分结合起来,产生对状态的更新。


【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

  1.3 LSTM--细胞状态更新

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

图5 细胞状态更新

现在更新旧的细胞状态 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 更新到 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 。先前的步骤已经决定要做什么,我们只需要照做就好。

我们对旧的状态乘以 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 ,用来忘记我们决定忘记的事。然后我们加上 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 ,这是新的候选值,根据我们对每个状态决定的更新值按比例进行缩放。

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

  1.4 LSTM--输出门

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

图6 输出门

最后,我们需要确定输出值。输出依赖于我们的细胞状态,但会是一个“过滤的”版本。首先我们运行 S 形网络层,用来确定细胞状态中的哪些部分可以输出。然后,我们把细胞状态输入 tanh(把数值调整到 −1 和 1 之间)再和 S 形网络层的输出值相乘,部这样我们就可以输出想要输出的分。


2. LSTM的变种以及前向、反向传播

目前所描述的还只是一个相当一般化的 LSTM 网络。但并非所有 LSTM 网络都和之前描述的一样。事实上,几乎所有文章都会改进 LSTM 网络得到一个特定版本。差别是次要的,但有必要认识一下这些变种。


【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

  2.1 带有"窥视孔连接"的LSTM

一个流行的 LSTM 变种由 Gers 和 Schmidhuber 提出,在 LSTM 的基础上添加了一个“窥视孔连接”,这意味着我们可以让门网络层输入细胞状态。

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

图7 添加“窥视孔连接”的LSTM

上图中我们为所有门添加窥视孔,但许多论文只为部分门添加。为了更直观的推导反向传播算法,将上图转化为下图:

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

图8 转化后的窥视孔LSTM

前向传播:在t时刻的前向传播公式为:

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

反向传播:对反向传播算法了解不够透彻的,请参考陈楠:反向传播算法推导过程(非常详细),这里有详细的推导过程,本文将直接使用其结论。

已知: 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 ,求某个节点梯度时,首先应该找到该节点的输出节点,然后分别计算所有输出节点的梯度乘以输出节点对该节点的梯度,最后相加即可得到该节点的梯度。如计算 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 时,找到 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 节点的所有输出节点 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 ,然后分别计算输出节点的梯度(如 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 )与输出节点对 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 的梯度的乘积(如 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 ),最后相加即可得到节点 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 的梯度:

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

同理可得t时刻其它节点的梯度:

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

对参数的梯度:

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法


【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

  2.2 GRU

一个更有意思的 LSTM 变种称为 Gated Recurrent Unit(GRU),由 Cho 等人提出。LSTM通过三个门函数输入门、遗忘门和输出门分别控制输入值、记忆值和输出值。而GRU中只有两个门:更新门【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 和重置门 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 ,如下图所示。更新门用于控制前一时刻的状态信息被带入到当前状态中的程度,更新门的值越大说明前一时刻的状态信息带入越多;重置门控制前一时刻状态有多少信息被写入到当前的候选集 【通俗推导】人人都能看懂的LSTM介绍及反向传播算法 上,重置门越小,前一状态的信息被写入的越少。这样做使得 GRU 比标准的 LSTM 模型更简单,因此正在变得流行起来。

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

图9 GRU

为了更加直观的推导反向传播公式,将上图转化为如下形式:

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

图10 转换后的GRU

GRU的前向传播:在t时刻的前向传播公式为:

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

GRU的反向传播:t时刻其它节点的梯度:

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

对参数的梯度:

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法


【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

  2.3 遗忘门与输入门相结合的LSTM

另一个变种把遗忘和输入门结合起来。同时确定要遗忘的信息和要添加的新信息,而不再是分开确定。当输入的时候才会遗忘,当遗忘旧信息的时候才会输入新数据。

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法

图11 遗忘门与输入门相结合的LSTM

前向与反向算法与上述变种相同,这里不再做过多推导。

参考资料:【翻译】理解 LSTM 网络 - xuruilong100 - 博客园

猜您喜欢:

附下载 | 《Python进阶》中文版

附下载 | 经典《Think Python》中文版

附下载 | 《Pytorch模型训练实用教程》

附下载 | 最新2020李沐《动手学深度学习》

附下载 | 《可解释的机器学习》中文版

附下载 |《TensorFlow 2.0 深度学习算法实战》

附下载 | 超100篇!CVPR 2020最全GAN论文梳理汇总!

附下载 |《计算机视觉中的数学方法》分享

【通俗推导】人人都能看懂的LSTM介绍及反向传播算法