Attention各个维度计算方法

Attention各个维度计算方法
这里是对self-Attention具体的矩阵操作,包括怎么separate head,如何进行的QK乘法等细节做了记录,以便自己以后查看。

dot-product Attention

Attention各个维度计算方法
其中的 X n , d m o d e l X^{n,d_{model}} Xn,dmodel一般是seq序列,n为序列的长度, d m o d e l d_{model} dmodel为序列的emedding维度。在self-attention中 d k d_k dk d v d_v dv是相等的。

multi-head Attention

其中的 X b s , l e n g t h , e m b X^{bs,length,emb} Xbs,length,emb一般是输入的序列,维度的意义如名字所示。

  1. 首先用三个矩阵 W Q W K W V W_QW_KW_V WQWKWV分别对QKV嵌入一个新的维度,emb2也就是projection_dim,当然也可以保持原有的维度不变。
  2. Q b s , l e n g t h , e m b 2 Q^{bs,length,emb2} Qbs,length,emb2举例,需要将head分离出来,做法也就是对Q的最后一个维度reshape
  3. 对QK做矩阵乘法,这里注意,Q和K的维度是四个维度( Q b s , h e a d , l e n g h t , e m b 2 / / h e a d Q^{bs,head,lenght,emb2//head} Qbs,head,lenght,emb2//head),这里的乘法是保持bs和head不变只在最后两个维度做乘法,所以得到的Attention矩阵 A b s , h e a d , l e n g h t , l e n g t h A^{bs,head,lenght,length} Abs,head,lenght,length,这里的意义就是用户序列的每一个词都对其余的词有一个attention值。
  4. A V T AV^T AVT得到 Y b s , h e a d , l e n g t h , e m b 2 / / h e a d Y^{bs,head,length,emb2//head} Ybs,head,length,emb2//head,对Y进行reshape一下将head去掉恢复原来QKV的形状 Y b s , l e n g t h , e m b 2 Y^{bs,length,emb2} Ybs,length,emb2
    Attention各个维度计算方法