Notes on the LSTM Model and Its Forward and Backward Propagation Algorithms


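These are my notes from reading Liu Jianping's (刘建平) article "LSTM Model and Forward/Backward Propagation Algorithms", written with reference to two other articles as well. Everything here, including the formula derivations, reflects my own understanding; if you spot an error, corrections are welcome.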


LSTM Forward Propagation

The forward-propagation formulas are given here directly; see the articles above for the details:

[Figure: LSTM forward-propagation equations]
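For reference, a forward pass consistent with the dimensions and the backward derivation in these notes can be sketched as follows. These are the standard LSTM equations; the softmax output layer is an assumption, inferred from the gradient $\boldsymbol V^T(\hat{\boldsymbol y}-\boldsymbol y)$ used later, which is what softmax with cross-entropy loss produces.

$$
\begin{aligned}
\boldsymbol f^{(t)}&=\sigma\left(\boldsymbol W_f\boldsymbol h^{(t-1)}+\boldsymbol U_f\boldsymbol x^{(t)}+\boldsymbol b_f\right)\\
\boldsymbol i^{(t)}&=\sigma\left(\boldsymbol W_i\boldsymbol h^{(t-1)}+\boldsymbol U_i\boldsymbol x^{(t)}+\boldsymbol b_i\right)\\
\boldsymbol a^{(t)}&=\tanh\left(\boldsymbol W_a\boldsymbol h^{(t-1)}+\boldsymbol U_a\boldsymbol x^{(t)}+\boldsymbol b_a\right)\\
\boldsymbol C^{(t)}&=\boldsymbol C^{(t-1)}\odot\boldsymbol f^{(t)}+\boldsymbol i^{(t)}\odot\boldsymbol a^{(t)}\\
\boldsymbol o^{(t)}&=\sigma\left(\boldsymbol W_o\boldsymbol h^{(t-1)}+\boldsymbol U_o\boldsymbol x^{(t)}+\boldsymbol b_o\right)\\
\boldsymbol h^{(t)}&=\boldsymbol o^{(t)}\odot\tanh\left(\boldsymbol C^{(t)}\right)\\
\hat{\boldsymbol y}^{(t)}&=\operatorname{softmax}\left(\boldsymbol V\boldsymbol h^{(t)}+\boldsymbol c\right)
\end{aligned}
$$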
where the shapes are:
$$
\begin{aligned}
&\boldsymbol x^{(t)}: n\times 1,\quad \boldsymbol h^{(t)}: m\times 1\\
&\boldsymbol V: l\times m,\quad \boldsymbol c: l\times 1,\quad \hat{\boldsymbol y}^{(t)}: l\times 1\\
&\boldsymbol o^{(t)},\boldsymbol f^{(t)},\boldsymbol i^{(t)},\boldsymbol a^{(t)},\boldsymbol C^{(t)}: m\times 1\\
&\boldsymbol W_f,\boldsymbol W_i,\boldsymbol W_a,\boldsymbol W_o: m\times m\\
&\boldsymbol U_f,\boldsymbol U_i,\boldsymbol U_a,\boldsymbol U_o: m\times n\\
&\boldsymbol b_f,\boldsymbol b_i,\boldsymbol b_a,\boldsymbol b_o: m\times 1
\end{aligned}
$$
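To make the shapes concrete, here is a minimal NumPy sketch of one forward step. It assumes the standard LSTM gate equations and a softmax output layer; the helper names `init_params` and `lstm_forward_step` and the dictionary keys are illustrative, not taken from the original article.

```python
import numpy as np

def sigmoid(z):
    """Logistic function, applied elementwise."""
    return 1.0 / (1.0 + np.exp(-z))

def init_params(n, m, l, seed=0):
    """Random parameters with the shapes listed above (illustrative helper)."""
    rng = np.random.default_rng(seed)
    p = {}
    for g in "fiao":  # forget, input, candidate, output
        p["W" + g] = rng.normal(scale=0.1, size=(m, m))
        p["U" + g] = rng.normal(scale=0.1, size=(m, n))
        p["b" + g] = np.zeros((m, 1))
    p["V"] = rng.normal(scale=0.1, size=(l, m))
    p["c"] = np.zeros((l, 1))
    return p

def lstm_forward_step(x, h_prev, C_prev, p):
    """One LSTM step; assumes standard gates and a softmax output."""
    f = sigmoid(p["Wf"] @ h_prev + p["Uf"] @ x + p["bf"])  # forget gate
    i = sigmoid(p["Wi"] @ h_prev + p["Ui"] @ x + p["bi"])  # input gate
    a = np.tanh(p["Wa"] @ h_prev + p["Ua"] @ x + p["ba"])  # candidate state
    o = sigmoid(p["Wo"] @ h_prev + p["Uo"] @ x + p["bo"])  # output gate
    C = C_prev * f + i * a                                 # new cell state
    h = o * np.tanh(C)                                     # new hidden state
    z = p["V"] @ h + p["c"]
    z = z - z.max()                                        # numerical stability
    y_hat = np.exp(z) / np.exp(z).sum()                    # softmax
    return h, C, y_hat

n, m, l = 3, 4, 2
p = init_params(n, m, l)
x = np.ones((n, 1))
h, C, y_hat = lstm_forward_step(x, np.zeros((m, 1)), np.zeros((m, 1)), p)
print(h.shape, C.shape, y_hat.shape)  # (4, 1) (4, 1) (2, 1)
```

Keeping every vector as a column of shape `(k, 1)` mirrors the $k\times 1$ notation, so a shape mismatch shows up immediately as a broadcasting or matmul error.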
The loss $L(t)$ splits into two parts: the loss $l(t)$ at time $t$ itself, and the loss $L(t+1)$ accumulated after time $t$:
$$
L(t)=\begin{cases}
l(t)+L(t+1), & t<\tau\\
l(t), & t=\tau
\end{cases}
$$
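The recursion can be illustrated with a few lines of Python. This is only a sketch: the per-step losses are made-up numbers, the function name is hypothetical, and time is indexed from 0 so that $\tau$ is the last index.

```python
def total_loss_from(t, step_losses):
    """L(t) = l(t) + L(t+1) for t < tau, and L(tau) = l(tau)."""
    tau = len(step_losses) - 1
    if t == tau:
        return step_losses[t]
    return step_losses[t] + total_loss_from(t + 1, step_losses)

step_losses = [0.5, 0.25, 1.0]          # illustrative l(0), l(1), l(2)
print(total_loss_from(0, step_losses))  # 1.75
print(total_loss_from(2, step_losses))  # 1.0
```

This is why the gradient at step $t$ has two sources: the local loss $l(t)$ and everything that flows back from $L(t+1)$.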

So at $t=\tau$ we have $\boldsymbol\delta_h^{(\tau)}=\frac{\partial L}{\partial \boldsymbol h^{(\tau)}}=\boldsymbol V^T(\hat{\boldsymbol y}^{(\tau)}-\boldsymbol y^{(\tau)})$; the detailed derivation is the same as for the RNN. In addition, treating $\boldsymbol o^{(\tau)}$ as fixed and varying only $\boldsymbol C^{(\tau)}$:
$$
\begin{aligned}
dL&=\left(\frac{\partial L}{\partial \boldsymbol h^{(\tau)}}\right)^Td\boldsymbol h^{(\tau)}\\
&=\left(\boldsymbol\delta_h^{(\tau)}\right)^Td\left(\boldsymbol o^{(\tau)}\odot \tanh(\boldsymbol C^{(\tau)})\right)\\
&=\operatorname{tr}\left(\left(\boldsymbol\delta_h^{(\tau)}\right)^T\left(\boldsymbol o^{(\tau)}\odot d\tanh(\boldsymbol C^{(\tau)})\right)\right)\\
&=\operatorname{tr}\left(\left(\boldsymbol\delta_h^{(\tau)}\odot \boldsymbol o^{(\tau)}\right)^Td\tanh(\boldsymbol C^{(\tau)})\right)\\
&=\operatorname{tr}\left(\left(\boldsymbol\delta_h^{(\tau)}\odot \boldsymbol o^{(\tau)}\right)^T\left[\left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(\tau)})\right)\odot d\boldsymbol C^{(\tau)}\right]\right)\\
&=\operatorname{tr}\left(\left[\boldsymbol\delta_h^{(\tau)}\odot \boldsymbol o^{(\tau)}\odot \left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(\tau)})\right)\right]^Td\boldsymbol C^{(\tau)}\right)
\end{aligned}
$$
Therefore, $\boldsymbol\delta_C^{(\tau)}=\frac{\partial L}{\partial \boldsymbol C^{(\tau)}}=\boldsymbol\delta_h^{(\tau)}\odot \boldsymbol o^{(\tau)}\odot \left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(\tau)})\right)$.
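A finite-difference check is a quick way to gain confidence in this result. The sketch below assumes a softmax output with cross-entropy loss (which is what produces $\boldsymbol V^T(\hat{\boldsymbol y}-\boldsymbol y)$) and uses random illustrative values for $\boldsymbol V$, $\boldsymbol c$, $\boldsymbol o^{(\tau)}$ and $\boldsymbol C^{(\tau)}$.

```python
import numpy as np

rng = np.random.default_rng(0)
m, l = 4, 3
V = rng.normal(size=(l, m))
c = rng.normal(size=(l, 1))
o = 1.0 / (1.0 + np.exp(-rng.normal(size=(m, 1))))  # output gate, in (0, 1)
C = rng.normal(size=(m, 1))                         # cell state at t = tau
y = np.zeros((l, 1))
y[1] = 1.0                                          # one-hot target

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def loss_from_C(C_val):
    """Cross-entropy loss at t = tau as a function of C, with o held fixed."""
    h = o * np.tanh(C_val)
    return float(-(y * np.log(softmax(V @ h + c))).sum())

# Analytic gradient from the formulas above
h = o * np.tanh(C)
delta_h = V.T @ (softmax(V @ h + c) - y)
delta_C = delta_h * o * (1.0 - np.tanh(C) ** 2)

# Central finite differences, one coordinate at a time
eps = 1e-6
numeric = np.zeros_like(C)
for k in range(m):
    e = np.zeros_like(C)
    e[k] = eps
    numeric[k] = (loss_from_C(C + e) - loss_from_C(C - e)) / (2 * eps)

print(np.abs(delta_C - numeric).max())  # should be tiny
```

The same pattern (perturb one input, compare with the closed form) works for every gradient derived in these notes.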
Next, we work backward from step $t+1$ to step $t$:
$$
\begin{aligned}
dL&=dl(t)+dL(t+1)\\
&=\left(\frac{\partial l(t)}{\partial \boldsymbol h^{(t)}}\right)^Td\boldsymbol h^{(t)}+\left(\frac{\partial L(t+1)}{\partial \boldsymbol h^{(t+1)}}\right)^Td\boldsymbol h^{(t+1)}\\
&=\left(\frac{\partial l(t)}{\partial \boldsymbol h^{(t)}}\right)^Td\boldsymbol h^{(t)}+\left(\frac{\partial L(t+1)}{\partial \boldsymbol h^{(t+1)}}\right)^T\left(\frac{\partial \boldsymbol h^{(t+1)}}{\partial \boldsymbol h^{(t)}}\right)^Td\boldsymbol h^{(t)}
\end{aligned}
$$
Therefore:
$$
\begin{aligned}
\boldsymbol \delta_h^{(t)}&=\frac{\partial L}{\partial \boldsymbol h^{(t)}}\\
&=\frac{\partial l(t)}{\partial \boldsymbol h^{(t)}}+\frac{\partial \boldsymbol h^{(t+1)}}{\partial \boldsymbol h^{(t)}}\boldsymbol \delta_h^{(t+1)}
\end{aligned}
$$
where $\frac{\partial \boldsymbol h^{(t+1)}}{\partial \boldsymbol h^{(t)}}$ is obtained as follows:
$$
\begin{aligned}
d\boldsymbol h^{(t+1)}&=d\left(\boldsymbol o^{(t+1)}\odot \tanh(\boldsymbol C^{(t+1)})\right)\\
&=\tanh(\boldsymbol C^{(t+1)})\odot d\boldsymbol o^{(t+1)}+\boldsymbol o^{(t+1)}\odot d\tanh(\boldsymbol C^{(t+1)})\\
&=\tanh(\boldsymbol C^{(t+1)})\odot \boldsymbol o^{(t+1)}\odot (\boldsymbol 1-\boldsymbol o^{(t+1)})\odot d\left(\boldsymbol W_o\boldsymbol h^{(t)}\right)\\
&\quad+\boldsymbol o^{(t+1)}\odot \left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(t+1)})\right)\odot d\boldsymbol C^{(t+1)}\\
&=\operatorname{diag}\left[\tanh(\boldsymbol C^{(t+1)})\odot \boldsymbol o^{(t+1)}\odot (\boldsymbol 1-\boldsymbol o^{(t+1)})\right]\boldsymbol W_od\boldsymbol h^{(t)}\\
&\quad+\operatorname{diag}\left[\boldsymbol o^{(t+1)}\odot \left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(t+1)})\right)\right]d\boldsymbol C^{(t+1)}
\end{aligned}
$$
For the cell state (recall that $\boldsymbol a^{(t+1)}$ is itself a $\tanh$ output, so its differential carries the factor $\boldsymbol 1-\boldsymbol a^{(t+1)}\odot\boldsymbol a^{(t+1)}$):
$$
\begin{aligned}
\boldsymbol C^{(t+1)}&=\boldsymbol C^{(t)}\odot \boldsymbol f^{(t+1)}+\boldsymbol i^{(t+1)}\odot \boldsymbol a^{(t+1)}\\
d\boldsymbol C^{(t+1)}&=\boldsymbol f^{(t+1)}\odot d\boldsymbol C^{(t)}+\boldsymbol C^{(t)}\odot d\boldsymbol f^{(t+1)}+\boldsymbol a^{(t+1)}\odot d\boldsymbol i^{(t+1)}+\boldsymbol i^{(t+1)}\odot d\boldsymbol a^{(t+1)}\\
&=\boldsymbol f^{(t+1)}\odot d\boldsymbol C^{(t)}+\boldsymbol C^{(t)}\odot \boldsymbol f^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol f^{(t+1)}\right)\odot d\left(\boldsymbol W_f\boldsymbol h^{(t)}\right)\\
&\quad+\boldsymbol a^{(t+1)}\odot \boldsymbol i^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol i^{(t+1)}\right)\odot d\left(\boldsymbol W_i\boldsymbol h^{(t)}\right)\\
&\quad+\boldsymbol i^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol a^{(t+1)}\odot\boldsymbol a^{(t+1)}\right)\odot d\left(\boldsymbol W_a\boldsymbol h^{(t)}\right)\\
&=\boldsymbol f^{(t+1)}\odot d\boldsymbol C^{(t)}+\operatorname{diag}\left[\boldsymbol C^{(t)}\odot \boldsymbol f^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol f^{(t+1)}\right)\right]\boldsymbol W_fd\boldsymbol h^{(t)}\\
&\quad+\operatorname{diag}\left[\boldsymbol a^{(t+1)}\odot \boldsymbol i^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol i^{(t+1)}\right)\right]\boldsymbol W_id\boldsymbol h^{(t)}\\
&\quad+\operatorname{diag}\left[\boldsymbol i^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol a^{(t+1)}\odot\boldsymbol a^{(t+1)}\right)\right]\boldsymbol W_ad\boldsymbol h^{(t)}
\end{aligned}
$$
Substituting $d\boldsymbol C^{(t+1)}$ into $d\boldsymbol h^{(t+1)}$ produces a very large expression. The $\boldsymbol f^{(t+1)}\odot d\boldsymbol C^{(t)}$ term contains no $d\boldsymbol h^{(t)}$, so it does not contribute to $\frac{\partial \boldsymbol h^{(t+1)}}{\partial \boldsymbol h^{(t)}}$. You can work through the algebra yourself; the result is:
$$
\begin{aligned}
\frac{\partial \boldsymbol h^{(t+1)}}{\partial \boldsymbol h^{(t)}}&=\left(\boldsymbol W_o\right)^T\operatorname{diag}\left[\tanh(\boldsymbol C^{(t+1)})\odot \boldsymbol o^{(t+1)}\odot (\boldsymbol 1-\boldsymbol o^{(t+1)})\right]\\
&\quad+\left(\boldsymbol W_f\right)^T\operatorname{diag}\left[\boldsymbol C^{(t)}\odot \boldsymbol f^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol f^{(t+1)}\right)\odot \Delta\boldsymbol C\right]\\
&\quad+\left(\boldsymbol W_i\right)^T\operatorname{diag}\left[\boldsymbol a^{(t+1)}\odot \boldsymbol i^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol i^{(t+1)}\right)\odot \Delta\boldsymbol C\right]\\
&\quad+\left(\boldsymbol W_a\right)^T\operatorname{diag}\left[\boldsymbol i^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol a^{(t+1)}\odot\boldsymbol a^{(t+1)}\right)\odot \Delta\boldsymbol C\right]
\end{aligned}
$$
where $\Delta\boldsymbol C=\boldsymbol o^{(t+1)}\odot \left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(t+1)})\right)$.
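Because this expression is easy to get wrong by hand, a numerical check is worthwhile. The sketch below assumes the standard gate equations with $\boldsymbol a^{(t+1)}=\tanh(\cdot)$, so the candidate term uses the factor $\boldsymbol 1-\boldsymbol a\odot\boldsymbol a$. It uses random illustrative parameters and compares the closed form with central finite differences. Note the layout convention: the matrix written $\frac{\partial \boldsymbol h^{(t+1)}}{\partial \boldsymbol h^{(t)}}$ in these notes is the transpose of the usual Jacobian.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(1)
n, m = 3, 4
# Illustrative random parameters, keyed by gate: f, i, a, o
W = {k: rng.normal(scale=0.5, size=(m, m)) for k in "fiao"}
U = {k: rng.normal(scale=0.5, size=(m, n)) for k in "fiao"}
b = {k: rng.normal(scale=0.5, size=m) for k in "fiao"}
x_next = rng.normal(size=n)   # x^(t+1)
C_t = rng.normal(size=m)      # C^(t), held fixed
h_t = rng.normal(size=m)      # h^(t), the variable we differentiate by

def step(h, C_prev, x):
    """Compute h^(t+1) and the intermediate gate values."""
    f = sigmoid(W["f"] @ h + U["f"] @ x + b["f"])
    i = sigmoid(W["i"] @ h + U["i"] @ x + b["i"])
    a = np.tanh(W["a"] @ h + U["a"] @ x + b["a"])
    o = sigmoid(W["o"] @ h + U["o"] @ x + b["o"])
    C = C_prev * f + i * a
    return o * np.tanh(C), C, f, i, a, o

h_next, C_next, f, i, a, o = step(h_t, C_t, x_next)
dC = o * (1.0 - np.tanh(C_next) ** 2)  # the Delta C factor

# Closed form from the notes (transposed-Jacobian layout)
analytic = (
    W["o"].T @ np.diag(np.tanh(C_next) * o * (1.0 - o))
    + W["f"].T @ np.diag(C_t * f * (1.0 - f) * dC)
    + W["i"].T @ np.diag(a * i * (1.0 - i) * dC)
    + W["a"].T @ np.diag(i * (1.0 - a ** 2) * dC)
)

# Usual Jacobian by central differences: jac[k, j] = d h_next[k] / d h_t[j]
eps = 1e-6
jac = np.zeros((m, m))
for j in range(m):
    e = np.zeros(m)
    e[j] = eps
    jac[:, j] = (step(h_t + e, C_t, x_next)[0] - step(h_t - e, C_t, x_next)[0]) / (2 * eps)

max_err = np.abs(analytic - jac.T).max()
print(max_err)  # should be tiny
```

Swapping `1.0 - a ** 2` for `1.0 - np.tanh(a) ** 2` makes the check fail, which is a handy way to see why the candidate term must use the activation value itself.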
With $\boldsymbol \delta_h^{(t)}$ in hand, $\boldsymbol \delta_C^{(t)}$ follows easily:
[Figure: expression for $\boldsymbol \delta_C^{(t)}$]
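One way to write this recursion, sketched from the relations derived earlier rather than copied from the figure: $\boldsymbol C^{(t)}$ reaches the loss both through $\boldsymbol h^{(t)}=\boldsymbol o^{(t)}\odot\tanh(\boldsymbol C^{(t)})$ and through $\boldsymbol C^{(t+1)}=\boldsymbol C^{(t)}\odot\boldsymbol f^{(t+1)}+\boldsymbol i^{(t+1)}\odot\boldsymbol a^{(t+1)}$, which suggests

$$
\boldsymbol\delta_C^{(t)}=\boldsymbol\delta_C^{(t+1)}\odot\boldsymbol f^{(t+1)}+\boldsymbol\delta_h^{(t)}\odot\boldsymbol o^{(t)}\odot\left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(t)})\right)
$$

At $t=\tau$ there is no later step, so the first term vanishes and the expression reduces to the formula for $\boldsymbol\delta_C^{(\tau)}$ obtained earlier.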

With $\boldsymbol \delta_h^{(t)}$ and $\boldsymbol \delta_C^{(t)}$ available, the gradients of the remaining parameters follow easily.
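As an illustration, here is a sketch of three such gradients, assuming the standard gate equations (each gate at step $t$ takes $\boldsymbol h^{(t-1)}$ as input) and a softmax output with cross-entropy loss. The forget gate affects the loss only through $\boldsymbol C^{(t)}$, and the output gate only through $\boldsymbol h^{(t)}$, so summing over time steps gives

$$
\begin{aligned}
\frac{\partial L}{\partial \boldsymbol W_f}&=\sum_{t=1}^{\tau}\left[\boldsymbol\delta_C^{(t)}\odot\boldsymbol C^{(t-1)}\odot\boldsymbol f^{(t)}\odot\left(\boldsymbol 1-\boldsymbol f^{(t)}\right)\right]\left(\boldsymbol h^{(t-1)}\right)^T\\
\frac{\partial L}{\partial \boldsymbol W_o}&=\sum_{t=1}^{\tau}\left[\boldsymbol\delta_h^{(t)}\odot\tanh(\boldsymbol C^{(t)})\odot\boldsymbol o^{(t)}\odot\left(\boldsymbol 1-\boldsymbol o^{(t)}\right)\right]\left(\boldsymbol h^{(t-1)}\right)^T\\
\frac{\partial L}{\partial \boldsymbol V}&=\sum_{t=1}^{\tau}\left(\hat{\boldsymbol y}^{(t)}-\boldsymbol y^{(t)}\right)\left(\boldsymbol h^{(t)}\right)^T
\end{aligned}
$$

The gradients for $\boldsymbol W_i$, $\boldsymbol W_a$, the $\boldsymbol U$ matrices and the biases follow the same pattern, with the bracketed factor replaced by the matching term from $d\boldsymbol C^{(t)}$ and the right-hand vector replaced by $\boldsymbol x^{(t)}$ or dropped.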