attention 理解 根据pytorch教程seq2seq源码

https://blog.csdn.net/wuzqchom/article/details/75792501

http://baijiahao.baidu.com/s?id=1587926245504773589&wfr=spider&for=pc

pytorch源码

这是李宏毅老师的ppt。右侧对应pytorch seq2seq源码。

我们的问题是,左边的数学符号,右侧的代码是如何对应的?

attention 理解 根据pytorch教程seq2seq源码attention 理解 根据pytorch教程seq2seq源码

 

 

 

1、attention 理解 根据pytorch教程seq2seq源码不是embedding,而是encoder的output attention 理解 根据pytorch教程seq2seq源码,如源码中的output。

为什么是output而不是hidden呢?这要从之后的train函数中看出。

attention 理解 根据pytorch教程seq2seq源码

train函数中设置了一个大的,全是零的encoder_outputs的矩阵,红线部分将encoder_output存储起来,而hidden只是在不断的循环。从PPT可以看出来,每次是需要全部的h1,h2,h3,h4........,那么肯定使用了encoder_outputs 这个大大的矩阵。故是output对应attention 理解 根据pytorch教程seq2seq源码,而不是hidden。

其次注意,这里的GRU,seq长度只是1。它的序列的扩展是通过train函数的for循环,依次遍历每个单词,来进行序列方向上的扩展。

 

 

2、李宏毅老师match函数,在源码中是怎么实现的?回答:是通过定义的一层神经网络来实现的。

attention 理解 根据pytorch教程seq2seq源码

attention 理解 根据pytorch教程seq2seq源码.

可以看出来,解码器有个self.attn的线性层,这个线性层就是我们要找的match函数。为什么呢?看attendecoderRNN的forward中,拼接两个向量,再进行linear层,且函数名是attn_weights。正好对应的上面绿色箭头的*2 

attention 理解 根据pytorch教程seq2seq源码所以,这里的attn_weights就是attention 理解 根据pytorch教程seq2seq源码

 

 

3、attention 理解 根据pytorch教程seq2seq源码又对应什么呢?答,对应代码是:

attention 理解 根据pytorch教程seq2seq源码torch.bmm是batch 的乘法操作,即1*1*10 与1*10*256的矩阵会变成1*1*256

 

 

4、attention 理解 根据pytorch教程seq2seq源码是什么呢?答Z0是encoder的最后一个输出隐藏层encoder_hidden。为什么呢?依旧从源码看出来

attention 理解 根据pytorch教程seq2seq源码

在for循环第一遍输入的时候,就将decoder_hidden送入其中。对应decoder的输入参数attention 理解 根据pytorch教程seq2seq源码

而decoder_hidden又是编码器最后一个状态输出。所以李宏毅老师说的initial_memory,我认为就是编码器最后一个隐藏状态。

 

 

5、Z1又是什么?回答是 attn_weight 与 输入的  德文单词  的词向量相乘后的结果。注意,train的时候可以使用真实的单词,即teaching forcing,故是 正确标注的德文向量。如果不开启的话,则将预测的德文单词作为输入,转换成embedding向量与attn_weight进行操作。对应的代码是这一行:

attention 理解 根据pytorch教程seq2seq源码

 

 

6、那么PPT上的输出翻译后的单词  对应代码哪一块呢?

attention 理解 根据pytorch教程seq2seq源码这个箭头,对应的attention 理解 根据pytorch教程seq2seq源码

attention 理解 根据pytorch教程seq2seq源码这个箭头,对应的attention 理解 根据pytorch教程seq2seq源码

因为使用了GRU。:)

以上只是个人理解,请指出错误