Pytorch学习笔记之LSTM
Pytorch学习笔记之LSTM
看了理解LSTM这篇博文,在这里写写自己对LSTM网络的一些认识!。
- RNN
- 网络计算过程
Recurrent Neural Networks
人类并不是每时每刻都从一片空白的大脑开始他们的思考。在你阅读这篇文章时候,你都是基于自己已经拥有的对先前所见词的理解来推断当前词的真实含义。我们不会将所有的东西都全部丢弃,然后用空白的大脑进行思考。我们的思想拥有持久性。
传统的神经网络并不能做到这点,看起来也像是一种巨大的弊端。例如,假设你希望对电影中的每个时间点的时间类型进行分类。传统的神经网络应该很难来处理这个问题——使用电影中先前的事件推断后续的事件。
RNN 解决了这个问题。RNN 是包含循环的网络,允许信息的持久化
这是一个经典的RNN的流程图。
1. LSTM网络
经典的LSTM的流程图:
相信大家都看过这个图(盗用别人的图)。
再来一段公式,就是下面的,公式来自Pytorch。
is the hidden state at time , is the cell state at time , is the input at time , is the hidden state of the previous layer at time or the initial hidden state at time , and , , , are the input, forget, cell, and output gates, respectively. is the sigmoid function.
2. 内部计算分析
rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
可以看到参数的大小变成了(4*20,10),是标准RNN的四倍。原因是这里它包括了四个参数矩阵、、、,它们每一个都是(20×10),输入的维度大小是(10×1), 这样 , , , 的维度都是(20×1),公式(5)(6)的运算应该是叉积(元素积),这样得到的和的维度才能是20。
如上图所示hn和cn的最后一维都是20。注意这里的LSTM网络是单向,双向的要*2。蟹蟹!