seq2seq + attention 详解
seq2seq + attention 详解
作者:xy_free 时间:2018.05.21
1. seq2seq模型
seq2seq模型最早可追溯到2014年的两篇paper [1, 2],主要用于机器翻译任务(MT)。seq2seq本质上是一种encoder-decoder框架,以翻译任务中的“英译汉”为例,模型首先使用编码器对英文进行编码,得到英文的向量化表示S,然后使用解码器对S进行解码,得到对应的中文。由于encoder与decoder两端处理的都是序列数据,所以被称为sequence-to-sequence,简称seq2seq。另外,目前应用最多的编/解码器是RNN(LSTM,GRU),但编/解码器并不限于RNN,如也有人拿MLP作为编码器。
paper[1, 2]的主要结构如下图:
2. attention模型
attention模型最早出现于cv领域,而首次用于解决nlp问题是在2014年[3],seq2seq+attention 应用于机器翻译任务。以英译汉为例,当解码器对英文进行解码时,是一个词一个词生成的,而所生成的每个词对应的英文部分应该是不同,换句话说就是,解码器解码时不同step所分配的注意力是不同的。 再举一个例子,如看图说话(用一句话描述一幅图),所生成的词语应该对应图中的不同部分,即解码器在解码时,应该给图中“合适”的部位,分配更多的注意力(权重)。
paper[3]的主要结构如下图:
红圈标识的是编码器,其中h代表源文本的语义表示;紫圈标识的解码器,其中s代表目标文本的序列状态。c表示注意力向量,用来在解码时,控制源文本不同位置的attention分配
3. seq2seq + attention
以paper[3] 为例,对seq2seq + attention 的计算过程,进行详细说明,见上图(Translation: Attention Mechanism)
1. 使用 Bi-GRU 作为编码器,得到源文本的向量表示
详解如下:
- 其中, 表示Bi-GRU,表示正向GRU的输出, 表示反向GRU的输出,[]表示串联
2.对进行解码,获得目标序列
模型所要生成的目标是个“词序列”,处理方式是每次生成一个词,迭代进行
其中f是 维度映射 + maxout,maxout是一种**函数,维度映射是把所生成的向量转化为词表大小
-
是目标序列上一个词的词向量
在模型训练阶段,有两种选择(按比例选):一种是真实的训练样本词向量,另一种是生成的词的词向量,前一种方式也被称为 teacher forcing
在模型测试阶段, 是指生成的词的词向量 - 是序列的当前状态,,其中 表示GRU
-
表示注意力分配,详细计算如下:
其中 都是待学习参数, 可以理解为 关于的一个加权平均值,权重为
4. attention 扩展
attention很火,paper[4] 提出了一种attention改良方案,将attention划分为了两种形式:global, local.
global方式认为attention应该在所有源文本上进行,而local方式认为attention仅应该在部分源文本上进行。global理念与paper[3]相同,具体计算方式如下图所示:
其中“concat” 与 paper[3] 中的计算方式相同
另外,paper[4]除了改良了attention计算方式以外,还调整了decoder的计算方式,简化计算,优化编码
-
→
→
差异:
- 改变了的计算方式,除concat外,dot、general 可以作为备选
- paper[3]中,由 组成,而最终计算时,仍需考虑 和 ,冗余
另外,是个RNN,在计算时,需要考虑,coding时,需使用for循环,会拖慢计算效率 - paper[4]中,仅由组成,最终计算 时,仅考虑,未冗余
另外,在计算时, 已知,coding时,可算出所有step的,进而计算所有的,所有操作都是向量化操作,不需使用for循环,会快很多 - 改变了的计算方式
paper[3]中,使用maxout作为最后的**函数, 即维度映射 + maxout
paper[4]中,使用softmax作为最后的**函数,即维度映射 + softmax
5. 需要注意的地方
- decoder 端的初始化: , 取encoder的反向RNN的初态的非线性,作为decoder的初态
- teacher forcing模式与测试时(生成模式)不同,所以训练过程不能完全都用teacher forcing,teacher forcing 与 生成模式应按比例分配
- beamsearch 只是在测试的时候用到
- 如果encoder 与 decoder 的序列都很长,显存装不下。可考虑对decoder端进行截断,分步优化(pytorch中 使用 state = state.detach())
- coding时,尽量别用for循环,会极大降低计算效率
6. 总结
paper[4]无论从理论结构,还是从coding上来看,都非常棒,计算细节赘述如下:
- 注: 可使用LSTM, GRu, Bi-LSTM 等
参考
- Sutskever I, Vinyals O, Le Q V. Sequence to sequence learning with neural networks[C]//Advances in neural information processing systems. 2014.
- Cho K, Van Merriënboer B, Gulcehre C, et al. Learning phrase representations using RNN encoder-decoder for statistical machine translation[J]. arXiv, 2014.
- Bahdanau D, Cho K, Bengio Y. Neural machine translation by jointly learning to align and translate[J]. arXiv, 2014. & ICLR, 2015.
- Luong M T, Pham H, Manning C D. Effective approaches to attention-based neural machine translation[J]. arXiv, 2015.