Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention阅读笔记
1 Abstract
最近在NLP领域,普遍都通过增大神经网络的深度来增强模型的性能和模型容量,然而简单的堆叠神经网络例如在Transformer模型中,会使模型变得难以训练,收敛性会变差并且还会使计算复杂度增高。
在本文中,作者认为深度模型难以收敛主要是因为梯度消失现象,而这一现象在Transformer中主要是由于残差链接和层正则化之间的相互影响。
在本文中,作者提出了两个方法来解决上述问题:
(1) Depth-scaled initialization(DS-Init) 该模型主要通过在初始化阶段减少模型参数的方差,并以此来减少残差链接输出的方差。从而缓解反向传播过程中通过正则化层时的梯度问题。
(2) Merged Attention sublayer(MAtt) 该模型主要为了缓解decoder的计算复杂性。作者将decoder端的self-attention和encoder-decoder attention进行组合,从而形成MAtt
2 Introduction
在图像领域,好的模型一般都有很深的深度,而在NLP领域,尤其是Transformer却难以加深层次。作者对比了近期其他研究者的工作,认为深度Transformer中遭受了很严重的梯度消失问题(如下图所示),作者认为这种现象是由于层间的残差链接以及层正则化的影响。
由图中可以看到,当模型的层数不断加深时,浅层网络中的梯度消失的现象十分明显。
具体来说,DS-Init通过在初始化阶段使用一个与层数相关的折扣系数(discount factor )
另一个十分重要的方面就是随着深度的增加,Transformer模型的训练和解码计算复杂度会急剧上升,为了缓解这个问题,作者使用了一种叫做merged attention network的模型来解决深层网络的计算复杂的问题。这个模型将decoder中分离的encoder-decoder attention和self attention进行了合并。合并后的子层由两部分,一部分是基于平均的attention,另一部分是encoder-decoder attention。作者也对ANN进行了参数上和结构上的简化,可以很好的缓解计算开销。
3 Vanishing Gradient Analysis
前面也反复提到多次,作者认为,梯度消失的是由于残差链接RC和层正则化LN的相互作用而产生的。在文中,作者深度分析了,为什么会产生梯度消失的现象,在分析之前作者首先给出了残差链接和层正则化的数学模型:
其中RC表示残差链接,LN表示层正则化。通过给定的传播方式,可以求得其反向传播的梯度分别为:
其中,表示LN层的错误,diag(·)表示对input建立对角矩阵,I表示单位矩阵。上式分别利用链式法则求得了相对于各个输出产生的错误。
随后作者比较了模型的错误比率和LN层以及RC层的错误比率:
通过观察上式可知,只有当时,梯度消失的问题才不会对模型造成太大影响。同时作者比较了based-transformer和加入了DS-Init的Transformer的模型比率:
其中self表示self-attention,cross表示encoder-decoder attention,FFN表示Feed forward network。
4 Depth-scaled Initialization
作者通过实验得出self-attention子层具有>1的趋势,而FFN具有<1的趋势。而特殊的,encoder-decoder attention的值偏小,作者认为,这是由于在反向传播过程中,只有Q在decoder中传递,而K与V都传递到了encoder中。
所以,encoder的训练相对于decoder会容易一些,这也在BERT和GPT的成功上有所体现。
作者将这种梯度消失的原因归因于残差链接RC的输出方差过大,并提出了一种参数初始化方法:
其中是原来常用的参数初始化方法:
是该模型的超参数,为方法增加了灵活性。
均匀分布的方差为D(x)=(b-a)²/12,故可知,原来模型的参数由 变为了.通过该式子深层网络的RC部分具有较小的方差,从而使得梯度可以传至浅层部分。
5 Merged Attention Model
对于非常深的模型,会造成非常大的计算压力,这回造成模型过长的训练和推断的时间。为了解决这个问题,作者在文中提出了一个Merged Attention模型,这个模型将基于平均的self-attention(average-based self-attention)装入了encoder-decoder层中,其具体结构如下图所示:
今天先写到这,大体的模型已经记录完毕,日后会更新一些细节进来
参考论文:
Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention