GRU(Gated Recurrent Unit)初探

在处理序列任务的时候,由于RNN存在梯度消失和梯度爆炸的原因:

梯度消失RNN梯度消失是因为**函数tanh函数的倒数在01之间,反向传播时更新前面时刻的参数时,当参数W初始化为小于1的数,则多个(tanh函数’ * W)相乘,将导致求得的偏导极小(小于1的数连乘),从而导致梯度消失。
梯度爆炸:当参数初始化为足够大,使得tanh函数的倒数乘以W大于1,则将导致偏导极大(大于1的数连乘),从而导致梯度爆炸。

LSTM在1997年就提出来,通过门控单元来解决这个问题。在2014年GRU提出,相比LSTMGRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU

GRU的输入输出结构与普通的RNN是一样的,主要的修改在于内部的门控单元:
GRU(Gated Recurrent Unit)初探
在图中的红色框中的部分分别是重置门和更新门。
分别从上到下来看看公式:

  1. 更新门 z t z_t zt的数据流来自 h t − 1 h_{t-1} ht1 x t x_t xt,进行一个连接,然后输入到sigmod**函数中,构成更新门部分 z t z_t zt
  2. 同理,重置门 r t r_t rt的数据流来自 h t − 1 h_{t-1} ht1 x t x_t xt,进行一个连接,然后输入到sigmod**函数中。
  3. 我们跟随从 r t r_t rt ⨂ \bigotimes tanh的数据流,也就是将前一步的 r t r_t rt h t − 1 h_{t-1} ht1做一个对应 ⨂ \bigotimes (乘积),即: r t ∗ h h − 1 r_t * h_{h-1} rthh1,然后来自 x t x_t xt的部分进行连接,最后输入到tanh**函数中。
  4. h t h_t ht的部分,我们观察数据流,可以知道 ⨁ \bigoplus 是两个数据流部分的加和,而对应的部分分别是来自 z t z_t zt更新门的部分,分别是 ( 1 − z t ) (1-z_t) (1zt) z t z_t zt,也就是最后的公式四。