Pytorch学习笔记之LSTM

Pytorch学习笔记之LSTM


看了理解LSTM这篇博文,在这里写写自己对LSTM网络的一些认识!。

  • RNN
  • 网络计算过程

Recurrent Neural Networks

人类并不是每时每刻都从一片空白的大脑开始他们的思考。在你阅读这篇文章时候,你都是基于自己已经拥有的对先前所见词的理解来推断当前词的真实含义。我们不会将所有的东西都全部丢弃,然后用空白的大脑进行思考。我们的思想拥有持久性。
传统的神经网络并不能做到这点,看起来也像是一种巨大的弊端。例如,假设你希望对电影中的每个时间点的时间类型进行分类。传统的神经网络应该很难来处理这个问题——使用电影中先前的事件推断后续的事件。
RNN 解决了这个问题。RNN 是包含循环的网络,允许信息的持久化

Pytorch学习笔记之LSTM
这是一个经典的RNN的流程图。


1. LSTM网络

经典的LSTM的流程图:

Pytorch学习笔记之LSTM

相信大家都看过这个图(盗用别人的图)。
再来一段公式,就是下面的,公式来自Pytorch。
Pytorch学习笔记之LSTM
hth_t is the hidden state at time tt , ctc_t is the cell state at time tt , xtx_t is the input at time tt, h(t1)h_{(t-1)} is the hidden state of the previous layer at time t1t-1 or the initial hidden state at time 00 , and iti_t , ftf_t , gtg_t , oto_t are the input, forget, cell, and output gates, respectively. σ\sigma is the sigmoid function.

2. 内部计算分析

rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))

Pytorch学习笔记之LSTM

可以看到参数的大小变成了(4*20,10),是标准RNN的四倍。原因是这里它包括了四个参数矩阵WiiW_{ii}WifW_{if}WigW_{ig}WioW_{io},它们每一个都是(20×10),输入的维度大小是(10×1), 这样iti_t , ftf_t , gtg_t , oto_t 的维度都是(20×1),公式(5)(6)的运算应该是叉积(元素积),这样得到的ctc_thth_t的维度才能是20。

Pytorch学习笔记之LSTM
如上图所示hn和cn的最后一维都是20。注意这里的LSTM网络是单向,双向的要*2。蟹蟹!