GRU(Gated Recurrent Unit)初探
在处理序列任务的时候,由于RNN
存在梯度消失和梯度爆炸的原因:
梯度消失:
RNN
梯度消失是因为**函数tanh
函数的倒数在0
到1
之间,反向传播时更新前面时刻的参数时,当参数W
初始化为小于1
的数,则多个(tanh
函数’ *W
)相乘,将导致求得的偏导极小(小于1
的数连乘),从而导致梯度消失。
梯度爆炸:当参数初始化为足够大,使得tanh
函数的倒数乘以W
大于1
,则将导致偏导极大(大于1
的数连乘),从而导致梯度爆炸。
LSTM
在1997年就提出来,通过门控单元来解决这个问题。在2014年GRU
提出,相比LSTM
,GRU
能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU
。
GRU
的输入输出结构与普通的RNN
是一样的,主要的修改在于内部的门控单元:
在图中的红色框中的部分分别是重置门和更新门。
分别从上到下来看看公式:
- 更新门
z
t
z_t
zt的数据流来自
h
t
−
1
h_{t-1}
ht−1和
x
t
x_t
xt,进行一个连接,然后输入到
sigmod
**函数中,构成更新门部分 z t z_t zt。 - 同理,重置门
r
t
r_t
rt的数据流来自
h
t
−
1
h_{t-1}
ht−1和
x
t
x_t
xt,进行一个连接,然后输入到
sigmod
**函数中。 - 我们跟随从
r
t
r_t
rt到
⨂
\bigotimes
⨂到
tanh
的数据流,也就是将前一步的 r t r_t rt和 h t − 1 h_{t-1} ht−1做一个对应 ⨂ \bigotimes ⨂(乘积),即: r t ∗ h h − 1 r_t * h_{h-1} rt∗hh−1,然后来自 x t x_t xt的部分进行连接,最后输入到tanh
**函数中。 - h t h_t ht的部分,我们观察数据流,可以知道 ⨁ \bigoplus ⨁是两个数据流部分的加和,而对应的部分分别是来自 z t z_t zt更新门的部分,分别是 ( 1 − z t ) (1-z_t) (1−zt)和 z t z_t zt,也就是最后的公式四。