自然语言处理 | (25) 完全图解Seq2Seq with Attention模型
本文转载自知乎,原文见上述链接。本文通过图片,详细的画出了Seq2Seq with Attention模型的全部流程,帮助大家理解机器翻译、语音识别等任务的重要模型。
目录
1. 大框架
Seq2Seq是一个Encoder-Decoder结构的网络,他的输入是一个序列,输出也是一个序列,如机器翻译中输入可能是一段英文,输出则是对应的中文。Encoder中将一个可变长度的信号序列编码为固定长度的向量表达(代表输入序列的语义特征),Decoder则将这个固定长度的向量解码为可变长度的目标信号序列。
输入序列:
输出序列:
2. 详细图及公式
- (1)
,encoder端使用LSTM提取特征,其中
代表encoder LSTM每一个时间步骤上的隐藏状态,
表示encoder LSTM每一个时间步骤上输入序列中的一个term(词)的嵌入表示。
,
是输入序列的长度,即encoder LSTM需要经过的时间步骤数。 Encoder 端LSTM在每一个时间步骤上,接受输入序列中每一个词的嵌入表示,和上一个时间步骤上的隐藏状态
(
需要初始化),输出当前时间步骤上的隐藏状态
.(具体LSTM单元内部的运算细节,这里不再详细展开,之后会专门介绍)。
- (2)
,decoder端使用LSTM进行解码,其中
代表decoder LSTM每一个时间步骤上的隐藏状态,
表示目标序列中term(单词)的嵌入表示。Decoder 端LSTM,在当前时间步骤上,接受目标序列中上一个单词的嵌入表示
(
需要初始化,表示序列的开始)和上一个时间步骤上的隐藏状态
,输出当前时间步骤上的隐藏状态
,传给下一个时间步骤,并且预测当前时间步骤上的输出,即目标序列中下一个单词出现的概率(具体LSTM单元内部的运算细节,这里不再详细展开,之后会专门介绍).
- (3)
, 在decoder LSTM的每一个时间步骤上,都会计算一个上下文向量 context vector,它是对encoder每一个时间步骤上输出的隐藏状态
做一个加权平均得到的。该context vector(第一个时间步骤上的context vector初始化为0)可以和下一个时间步骤上目标序列中单词的嵌入进行拼接,作为输入。
- (4)
,通过结合decoder端LSTM当前时间步骤上的隐藏状态
和encoder端每个时间步骤上的隐藏状态
,通过score函数(之后会详细介绍运算细节)来计算decoder端LSTM当前时间步骤上的得分
. 该分数通过以下的公式来计算encoder LSTM中每一个时间步骤上隐藏状态
所对应的权重
,
,通过(3)中的公式计算当前时间步骤上的context vector,一方面和当前时间步骤的隐藏状态进行拼接进行输出预测,另一方面和目标序列中单词的嵌入进行拼接,作为下一个时间步骤的输入。
下一个时间步骤:
- (5)
,将decoder LSTM当前时间步骤上计算的context vector
和当前时间步骤上的隐藏状态
进行拼接,经过一个全联接层(使用tanh**函数),得到
. 其他时间步骤也是同样的操作。
- (6)
,再经过一个softmax层(输出层使用softmax**函数),计算最后的输出概率。
注意下图中的红色部分,在训练阶段是目标序列中每个词的嵌入,来预测下一个词出现的概率;在测试阶段,我们不知道目标序列,所以要根据上一个时间步骤的输出概率,进行采样,得到一个单词,用该单词的嵌入作为当前时间步骤的输入:
三、score的计算方法
在luong中提到了三种score的计算方法,这里图解前两种:
- Attention score function: dot
输入是encoder中所有隐藏状态 H:大小为(hidden dim,sequence_length()). decoder在一个时间步骤上的隐藏状态 s:大小为(hidden dim,1)
第一步:旋转H为(sequence_length(),hidden dim)与s点乘得到一个大小为(sequence_length(
),1)的分数向量e
第二步:对分数向量做一个softmax得到一个和为1的权重向量(sequence_length(
),1)
第三步:将H中的所有隐藏状态(hidden dim,sequence_length())和第二步计算的权重向量
(sequence_length(
),1)进行点乘(加权求和),得到一个大小为(hidden dim,1)的context vector
- Attention score function:general
输入是encoder中所有隐藏状态 H:大小为(hidden dim1,sequence_length()). decoder在一个时间步骤上的隐藏状态 s:大小为(hidden dim2,1)。此时两个hidden state的大小可能不一样。
第一步:旋转H为(sequence_length(),hidden dim1)与一个大小为(hidden dim1,hidden dim2)的矩阵
点乘,再与s点乘得到一个大小为(sequence_length(
),1)的分数向量e
第二步:对分数向量做一个softmax得到一个和为1的权重向量(sequence_length(
),1)
第三步:将H中的所有隐藏状态(hidden dim1,sequence_length())和第二步计算的权重向量
(sequence_length(
),1)进行点乘(加权求和),得到一个大小为(hidden dim1,1)的context vector
四、总结
看懂一个模型的最好办法就是在心里想一遍从输入到模型到输出每一个步骤里,tensor是如何流动的。
上述提到的Seq2Seq with Attention都是最基本的结构,相当于积木,在实际上使用时,会产生各种各样的变体,如encoder端采用双向LSTM进行编码,或堆叠多层双向LSTM,或把LSTM换成GRU,而且Decodrer端也可以堆叠多层单向LSTM或GRU,此外之间的Attention机制也会有一些复杂的变体。但是我们只有把这些基础结构学好,学会每块积木的用途和原理,接下里就可以灵活的使用这些积木,基于自己的应用,搭建出各种复杂的网络结构了。