RNN中梯度消失和爆炸的问题公式推导

RNN

首先来看一下经典的RRN的结构图,这里 xx 是输入 WW 是权重矩阵 (RNN的权重矩阵是共享的所以都是W) hh 是隐藏状态 yy是输出
RNN中梯度消失和爆炸的问题公式推导

RNN简单公式定义

ht=Wf(ht1)+W(hx)x[t] h_t = W*f(h_{t-1}) + W^{(hx)}*x_{[t]}
yt=W(S)f(ht) y_{t} = W^{(S)}*f(h_t)
其中,hth_t表示 t 时刻的隐藏状态 x[t]x_{[t]} 表示 t 时刻的输入 yty_t 表示 t 时刻的输出。我们记总体的error为 EE 那么 EE 有如下表达式:
E=t=1TEtW E = \sum_{t=1}^{T}\frac{\partial E_t}{\partial W}
总体的误差是所有时刻 t 的误差的累加。那么继续往下展开, 根据链式法则:
EtW=k=1tEtytythththkhkW \frac{\partial E_t}{\partial W} = \sum_{k=1}^{t}\frac{\partial E_t}{\partial y_t} \frac{\partial y_t}{\partial h_t}\frac{\partial h_t}{\partial h_k} \frac{\partial h_k}{\partial W}
继续往下展开有:
hthk=j=k+1thjhj1 \frac{\partial h_t}{\partial h_k} = \prod_{j=k+1}^{t}\frac{\partial h_j}{\partial h_{j-1}}
注意到:ht=Wf(ht1)+W(hx)x[t]h_t = W*f(h_{t-1}) + W^{(hx)}*x_{[t]},上式的每个偏导其实是一个Jacobian式

RNN中梯度消失和爆炸的问题公式推导

考虑Jacobians的范数,令:
hjhj1WTdiag[f(hj1)]βwβh ||\frac{\partial h_j}{\partial h_{j-1}} || \leq ||W^{T}|| *||diag[f'(h_{j-1})]|| \leq \beta_w*\beta_h
其中,βw,βh\beta_w ,\beta_h 表示正则化的上界。将上式回代到连乘的式子得:
hthk=j=k+1thjhj1(βwβh)tk ||\frac{\partial h_t}{\partial h_k} ||= ||\prod_{j=k+1}^{t}\frac{\partial h_j}{\partial h_{j-1}}|| \leq(\beta_w *\beta_h)^{t-k}
这里得 t 表示 time-step,也就是序列越长t会越大,即就变成了长期依赖的问题。注意到(βwβh)tk(\beta_w *\beta_h)^{t-k} 这项其实与矩阵的W的初始化有关,假设初始化一些非常小的数,W的范数也会变得很小,也就是βw\beta_w会变得比较小,那么随着t的增长,这一指数项会趋近于0而导致梯度消失,相反,如果初始化成为大于1的数,则随着t的增长,会导致梯度爆炸。