理解Transformer架构 .02

通过提问题的方式,学习一下Bert中的Transformer架构,其中这篇文章与《理解Transformer架构 .01》的内容互为补充:

目录

1. Transformer的结构是什么样?

1.1  encoder端与decoder端总览

1.2 encoder端各个子模块

1.2.1 多头 self-attention 模块

1.2.2 前馈神经网络层

1.3 decoder端各个子模块

1.3.1 多头self-attention 模块

1.3.2 多头encoder-decoder attention 交互模块

1.3.3 前馈神经网络模块(Feed-Forward Net,FFN)

1.4 Add &Norm 模块

1.5 position encoding(位置编码)

2. Transformer中一直强调的self-attention是什么?self-attention的计算过程?为什么它能发挥如此重大的作用?self-attention为什么要用Q、K、V,仅仅使用Q、V/K、V或者V不行吗?

2.1 self-attention是什么?

2.2 self-attention的计算过程?

2.3 关于self-attention为什么它能发挥如此大的作用?

2.4 关于self-attention为什么要使用Q、K、V,仅仅使用Q、K/V、V或者V为什么不行?

3. 为什么要使用多头self-attention?

4. Transformer相比于RNN/LSTM,有什么优势?

5. Transformer的训练过程

6. self-attention公式中的归一化有什么作用?


1. Transformer的结构是什么样?

Transformer本身是一个典型的encoder-decoder模型,如果从模型层面来看,Transformer实际上就像一个seq2seq attention的模型:

                                                         理解Transformer架构 .02

1.1  encoder端与decoder端总览

  • encoder端由N个相同的大模块堆叠而成(论文中N=6),其中每个大模块又由两个子模块组成,分别是self-attention模块以及一个前馈神经网络层。

注意:encoder端每个大模块的输入是不一样的,第一个大模块的输入的输入序列的embedding(可以通过word2vec训练而来,再加上位置编码),其余模块的输入均是前一个模块的输出,最后一个模块的输出作为整个endocer端的输出

  • decoder端同样由N个相同的大模块堆叠而成,其中每个大模块由3个子模块组成,分别是self-attention模块、encoder-decoder attention交互模块以及一个前馈神经网络模块。

注意:decoder端每个大模块接收的输入是不一样的,其中第一个大模块(最底下那个)在模型训练和测试时接收的输入是不一样的,并且每次训练时接收的输入是不一样的(因为最底层的输入有 "shifted right"这一限制);其余模块的输入都是前一个模块的输出;最后一个模块的输出作为整个decoder端的输出

对于decoder端第一个大模块,其训练及测试时的输入为:

训练时每次的输入为上次的输入加上输入序列向后移一位的ground truth(例如,每向后移一位就是一个新的单词,那么就加上其对应的embedding),特别的,当decoder的time step为1时,也就是第一次接收输入,其输入为一个特殊的token,可能是目标序列开始的token如[CLS],也可能是源序列结尾的token如[SEP],其目标是预测下一个位置的单词是什么;如当time step为1时,就是预测输入序列的第一个单词是什么

注意:在实际现实中可能不会这样每次动态的输入,而是一次性把目标序列的embedding统统输入第一个大模块中,然后在多头attention模块中对序列进行mask即可。二在测试的时候,先生成第一个位置的输出,有了这个之后,第二次预测时,再将其加入输入序列,以此类推直至预测结束。

1.2 encoder端各个子模块

1.2.1 多头 self-attention 模块

在介绍多头self-attention之前,先看看self-attention模块,其结构图如下所示:

                                                                           理解Transformer架构 .02

上面的self-attention可以被描述为将query和key-value键值对的一组集合映射到输出,query、keys、values和输出都是向量,其中,query和keys的维度均为 理解Transformer架构 .02 ,values的维度是 理解Transformer架构 .02 (论文中  理解Transformer架构 .02 ),输出为被计算为values的加权和,其中分配给每个value的权重由query 与对应key的相似性函数计算得来。这种attention的形式被称为 “Scaled Dot-Product Attention”,对应到公式的形式为:

                                                                   理解Transformer架构 .02

而多头self-attention模块,则是将Q,K,V通过参数矩阵映射后(映射方式就是在Q,K,V后分别接一个全连接层),然后在做self-attention,将这个过程重复h次(论文中h=8),最后再将所有的结果拼接起来,再送入一个全连接层即可,图示如下:

                                                               理解Transformer架构 .02

对应到公式的形式为:

                                      理解Transformer架构 .02

其中,理解Transformer架构 .02

1.2.2 前馈神经网络层

前馈神经网络模块由两个线性变换组成,中间有一个ReLU**函数,对应到公式的形式为:

                                                  理解Transformer架构 .02

论文中前馈神经网络模块输入和输出的维度均为 理解Transformer架构 .02 ,其内层的维度 理解Transformer架构 .02 。

1.3 decoder端各个子模块

1.3.1 多头self-attention 模块

decoder端多头self-attention 模块与encoder端的一致,但是需要注意的是decoder端的多头self-attention需要mask,因为它在预测时,是“看不到未来的序列的”,所以要讲当前预测的单词(token)及其之后的单词token全部mask掉。

1.3.2 多头encoder-decoder attention 交互模块

多头encoder-decoderattention交互模块的形式与多头self-attention模块一致,唯一不同的是其Q,K,V矩阵的来源,其Q矩阵来源于下面子模块的输出(对应到图中即为masked多头self-attention模块经过Add&Norm后的输出),而K,V矩阵则来源于整个encoder端的输出,目的就是让decoder端的单词token给予encoder端对应的单词token更多的关注。

1.3.3 前馈神经网络模块(Feed-Forward Net,FFN)

该部分与encoder端的一致。

1.4 Add &Norm 模块

Add & Norm模块在encoder端和decoder端每个子模块的后面,其中Add是残差连接(残差连接可参考上一篇文章),Norm表示LayerNorm(在每一层计算每一个样本的均值与方差,而BatchNorm是为每一各batch计算每一层的样本均值与方差),因此encoder端和decoder端每个子模块的实际输出为:LayerNorm (x+ Sublayer(x)),其中Sublayer(x)为子模块的输出,而x+Sublayer(x)代表残差。

1.5 position encoding(位置编码)

position encoding 添加到encoder端和decoder端最底部的输入embedding。position encoding具有与embedding相同的维度 理解Transformer架构 .02 :

                                          理解Transformer架构 .02

,因此可以对两者进行求和。

具体做法是使用不同频率的正弦和余弦函数,公式如下:

                                                          理解Transformer架构 .02

                                                       理解Transformer架构 .02

其中pos为位置,i为维度,之所以选择这个函数是因为任意位置 理解Transformer架构 .02 可以表示为 理解Transformer架构 .02 的线性函数,因为三角函数有这样的特性:

                                                      理解Transformer架构 .02

                                                     理解Transformer架构 .02

注意:transformer中的positional encoding 不是通过网络学习得来的,而是直接使用上述公式计算得来,论文中也实验了利用网络学习positional encoding ,发现结果与上述基本一致,但是论文中选择了正弦和余弦函数版本,因为三角函数不受序列长度的限制,也就是可以对 更长的序列进行表示。

2. Transformer中一直强调的self-attention是什么?self-attention的计算过程?为什么它能发挥如此重大的作用?self-attention为什么要用Q、K、V,仅仅使用Q、V/K、V或者V不行吗?

2.1 self-attention是什么?

self-attention,也叫intra-attention,是一种通过自身和自身相关联的attention机制,从而得到一个更好的representation来表达自身,self-attention可以看成一般attention的一种特殊情况。在self-attention中,Q=K=V,序列中的每个单词token和该序列中其余单词token进行attention计算。self-attention的特点在于可以无视token之间的距离直接计算依赖关系,从而能够学习到序列的内部结构,实现起来也比较简单。

2.2 self-attention的计算过程?

参考1.2.1节的内容。

2.3 关于self-attention为什么它能发挥如此大的作用?

self-attention是一种自身和自身相关联的attention机制,这样能够得到一个更好的representation来表达自身,在多数情况下,自然会对下游任务有一定的促进作用,但是Transformer效果显著及其强大的特征抽取能力是否完全归功于其self-attention模块,还是存在 一定争议的。

很明显,模型中引入self-attention后会更容易捕获句子中长距离的相互依赖的特征,因为如果是RNN或者LSTM,需要依次序列计算,对于远距离的相互依赖的特征,需要经过若干时间步骤的信息累积才能将两者联系起来,而距离越远,有效捕获的可能性越小。

但self-attention在计算过程中会直接将句子中任意两个单词的联系通过一个计算步骤直接联系起来,所以远距离依赖特征之间的距离被极大缩减,有利于有效地利用这些特征。除此之外,self-attention对于增加计算的并行性也有直接帮助作用。

2.4 关于self-attention为什么要使用Q、K、V,仅仅使用Q、K/V、V或者V为什么不行?

self-attention使用Q、K、V,这样三个参数独立,模型的表达能力和灵活性显然比只用Q、V或者只用V要好些,当然主流attention的做法还有很多种,比如说seq2seq with attention也就只有hidden state 来做相似性的计算,处理不同的任务,attention的做法有细微的不同,但是主题思想还是一致的。

3. 为什么要使用多头self-attention?

论文中说到进行multi-headed self-attention的原因是将模型分为多个头,形成多个子空间,可以让模型去关注不同方面的信息,最后再将各个方面的信息综合起来

4. Transformer相比于RNN/LSTM,有什么优势?

RNN系列的模型T时刻隐层状态的计算,依赖两个输入,一个T时刻的句子输入单词 理解Transformer架构 .02 ,另一个是T-1时刻的隐层状态的输出 理解Transformer架构 .02 ,这是最能体现RNN本质特征的一点,RNN的历史信息是通过这个信息传输渠道往后传输的。而RNN 并行计算的问题就出在这里,因为t时刻的计算依赖t-1时刻的隐层计算结果,而t-1时刻的计算依赖t-2时刻的隐层计算结果,如此下去就形成了所谓的序列依赖关系。

5. Transformer的训练过程

  1. 首先,encoder端得到输入的encoding表示,将其输入到decoder端做交互式attention;
  2. 之后再decoder端接收器相应的输入,经过多头self-attention之后,结合encoder端的输出;
  3. 再经过FFN,得到decoder端的输出之后,最后经过一个线性全连接层,就可以通过softmax来预测下一个单词token,
  4. 然后根据softmax多分类的损失函数,将loss反向传播即可。

注意:encoder端可以并行计算,一次性将输入序列全部encoding出来,如对于某个序列理解Transformer架构 .02 ,self-attention可以直接计算理解Transformer架构 .02的点乘结果,而RNN系列的模型就必须按照顺序从x1计算到xn;但是decoder端不是这样的,它需要将被预测单词一个接一个预测出来。

6. self-attention公式中的归一化有什么作用?

计算过程中,随着 理解Transformer架构 .02 的增大,理解Transformer架构 .02 点积的结果也随之增大,这样会将softmax函数推入梯度非常小的区域,使得收敛困难(可能会出现梯度消失的情况)。