真正的完全图解Seq2Seq Attention模型

转载公众号:https://mp.weixin.qq.com/s/0k71fKKv2SRLv9M6BjDo4w
原创: 盛源车 机器学习算法与自然语言处理 1周前

https://zhuanlan.zhihu.com/p/40920384

作者:盛源车

知乎专栏:魔法抓的学习笔记

五分钟看懂seq2seq attention模型。

本文通过图片,详细地画出了seq2seq+attention模型的全部流程,帮助小伙伴们无痛理解机器翻译等任务的重要模型。

 

seq2seq 是一个Encoder–Decoder 结构的网络,它的输入一个序列,输出也是一个序列, Encoder 中将一个可变长度的信号序列变为固定长度的向量表达,Decoder 将这个固定长度的向量变成可变长度的目标的信号序列。--简书

 

好了别管了,接下来开始刷图吧。

大框架

 

 

真正的完全图解Seq2Seq Attention模型

想象一下翻译任务,input是一段英文,output是一段中文。

 

公式(直接跳过看图最佳)

输入: 真正的完全图解Seq2Seq Attention模型

输出: 真正的完全图解Seq2Seq Attention模型

(1) 真正的完全图解Seq2Seq Attention模型 , Encoder方面接受的是每一个单词word embedding,和上一个时间点的hidden state。输出的是这个时间点的hidden state。

(2) 真正的完全图解Seq2Seq Attention模型 , Decoder方面接受的是目标句子里单词的word embedding,和上一个时间点的hidden state。

(3) 真正的完全图解Seq2Seq Attention模型 , context vector是一个对于encoder输出的hidden states的一个加权平均。

(4) 真正的完全图解Seq2Seq Attention模型 , 每一个encoder的hidden states对应的权重。

(5) 真正的完全图解Seq2Seq Attention模型 , 通过decoder的hidden states加上encoder的hidden states来计算一个分数,用于计算权重(4)

(6) 真正的完全图解Seq2Seq Attention模型, 将context vector 和 decoder的hidden states 串起来。

(7) 真正的完全图解Seq2Seq Attention模型 ,计算最后的输出概率。

 

详细图

 

真正的完全图解Seq2Seq Attention模型

左侧为Encoder+输入,右侧为Decoder+输出。中间为Attention。

 

(1) 真正的完全图解Seq2Seq Attention模型 , Encoder方面接受的是每一个单词word embedding,和上一个时间点的hidden state。输出的是这个时间点的hidden state。

真正的完全图解Seq2Seq Attention模型

从左边Encoder开始,输入转换为word embedding, 进入LSTM。LSTM会在每一个时间点上输出hidden states。如图中的h1,h2,...,h8。

(2) 真正的完全图解Seq2Seq Attention模型 , Decoder方面接受的是目标句子里单词的word embedding,和上一个时间点的hidden state。

真正的完全图解Seq2Seq Attention模型

接下来进入右侧Decoder,输入为(1) 句首 <sos>符号,原始context vector(为0),以及从encoder最后一个hidden state: h8。LSTM的是输出是一个hidden state。(当然还有cell state,这里没用到,不提。)

(3) 真正的完全图解Seq2Seq Attention模型 , context vector是一个对于encoder输出的hidden states的一个加权平均。

(4) 真正的完全图解Seq2Seq Attention模型 , 每一个encoder的hidden states对应的权重。

(5) 真正的完全图解Seq2Seq Attention模型 , 通过decoder的hidden states加上encoder的hidden states来计算一个分数,用于计算权重(4)

真正的完全图解Seq2Seq Attention模型

Decoder的hidden state与Encoder所有的hidden states作为输入,放入Attention模块开始计算一个context vector。之后会介绍attention的计算方法。

下一个时间点

真正的完全图解Seq2Seq Attention模型

来到时间点2,之前的context vector可以作为输入和目标的单词串起来作为lstm的输入。之后又回到一个hiddn state。以此循环。

 

(6) 真正的完全图解Seq2Seq Attention模型, 将context vector 和 decoder的hidden states 串起来。

(7) 真正的完全图解Seq2Seq Attention模型 ,计算最后的输出概率。

真正的完全图解Seq2Seq Attention模型

另一方面,context vector和decoder的hidden state合起来通过一系列非线性转换以及softmax最后计算出概率。

 

在luong中提到了三种score的计算方法。这里图解前两种:

真正的完全图解Seq2Seq Attention模型

Attention score function: dot

 

真正的完全图解Seq2Seq Attention模型

输入是encoder的所有hidden states H: 大小为(hid dim, sequence length)。decoder在一个时间点上的hidden state, s: 大小为(hid dim, 1)。

第一步:旋转H为(sequence length, hid dim) 与s做点乘得到一个 大小为(sequence length, 1)的分数

第二步:对分数做softmax得到一个合为1的权重

第三步:将H与第二步得到的权重做点乘得到一个大小为(hid dim, 1)的context vector

 

Attention score function: general

 

真正的完全图解Seq2Seq Attention模型

输入是encoder的所有hidden states H: 大小为(hid dim1, sequence length)。decoder在一个时间点上的hidden state, s: 大小为(hid dim2, 1)。此处两个hidden state的纬度并不一样。

第一步:旋转H为(sequence length, hid dim1) 与 Wa [大小为 hid dim1, hid dim 2)] 做点乘, 再和s做点乘得到一个 大小为(sequence length, 1)的分数

第二步:对分数做softmax得到一个合为1的权重

第三步:将H与第二步得到的权重做点乘得到一个大小为(hid dim, 1)的context vector

 

完结