attention 理解 根据pytorch教程seq2seq源码
https://blog.****.net/wuzqchom/article/details/75792501
http://baijiahao.baidu.com/s?id=1587926245504773589&wfr=spider&for=pc
这是李宏毅老师的ppt。右侧对应pytorch seq2seq源码。
我们的问题是,左边的数学符号,右侧的代码是如何对应的?
1、不是embedding,而是encoder的output
,如源码中的output。
为什么是output而不是hidden呢?这要从之后的train函数中看出。
train函数中设置了一个大的,全是零的encoder_outputs的矩阵,红线部分将encoder_output存储起来,而hidden只是在不断的循环。从PPT可以看出来,每次是需要全部的h1,h2,h3,h4........,那么肯定使用了encoder_outputs 这个大大的矩阵。故是output对应,而不是hidden。
其次注意,这里的GRU,seq长度只是1。它的序列的扩展是通过train函数的for循环,依次遍历每个单词,来进行序列方向上的扩展。
2、李宏毅老师match函数,在源码中是怎么实现的?回答:是通过定义的一层神经网络来实现的。
.
可以看出来,解码器有个self.attn的线性层,这个线性层就是我们要找的match函数。为什么呢?看attendecoderRNN的forward中,拼接两个向量,再进行linear层,且函数名是attn_weights。正好对应的上面绿色箭头的*2
所以,这里的attn_weights就是
3、又对应什么呢?答,对应代码是:
torch.bmm是batch 的乘法操作,即1*1*10 与1*10*256的矩阵会变成1*1*256
4、是什么呢?答Z0是encoder的最后一个输出隐藏层encoder_hidden。为什么呢?依旧从源码看出来
在for循环第一遍输入的时候,就将decoder_hidden送入其中。对应decoder的输入参数
而decoder_hidden又是编码器最后一个状态输出。所以李宏毅老师说的initial_memory,我认为就是编码器最后一个隐藏状态。
5、Z1又是什么?回答是 attn_weight 与 输入的 德文单词 的词向量相乘后的结果。注意,train的时候可以使用真实的单词,即teaching forcing,故是 正确标注的德文向量。如果不开启的话,则将预测的德文单词作为输入,转换成embedding向量与attn_weight进行操作。对应的代码是这一行:
6、那么PPT上的输出翻译后的单词 对应代码哪一块呢?
这个箭头,对应的
这个箭头,对应的
因为使用了GRU。:)
以上只是个人理解,请指出错误