XLNet简介
背景知识
语言模型:自回归和自编码模式
图示:
黄色块为输入字符,蓝色块为字符的位置。
对于自回归语言模型,它希望通过已知的前半句预测后面的词或字。
对于自编码语言模型,它希望通过一句话预测被 Mask 掉的字或词,如上所示第 2 个位置的词希望通过第 1、3、5 个词进行预测。
自回归式的优缺点
计算效率比较高
只能编码单向语义
自编码式的优缺点
双向编码能力
BERT 假设要预测的词之间是相互独立的,即 Mask 之间相互不影响(不符合事实)
自编码语言模型在预训练过程中会使用 MASK 符号,但在下游 NLP 任务中并不会使用
更好的预训练语言模型
图片:
随机排列语言后,模型就开始依次预测不同位置的词(这种随机的分解顺序还能构建双向语义)。
这种排列语言模型就是传统自回归语言模型的推广,它将自然语言的顺序拆解推广到随机拆解。
不同顺序下预测X3
图片:
假设包含单词Ti的当前输入的句子X为:
x1,x2,x3,x4
现假设要预测的单词Ti是x3,如何使预测它的条件里也包含其后的单词呢?
可以这么做:
随机排列组合句子中的4个单词,在随机排列组合后的各种可能里,再选择一部分作为模型预训练的输入X。
理解:
形式上仍然是个自回归的从左到右的语言模型,但通过对句子中单词排列组合,把一部分Ti下文的单词排到Ti的上文位置中,于是同时应用了上文和下文。
如何实现
在Transformer内部,通过Attention掩码,从X的输入单词里面,也就是Ti的上文和下文单词中,随机选择i-1个,放到Ti的上文位置中,把其它单词的输入通过Attention掩码隐藏掉,于是就能够达成我们期望的目标(当然这个所谓放到Ti的上文位置,只是一种形象的说法,其实在内部,就是通过Attention Mask,把其它没有被选到的单词Mask掉,不让它们在预测单词Ti的时候发生作用,如此而已。看着就类似于把这些被选中的单词放到了上文Context_before的位置了)。
一般的掩码机制存在的问题
双流自注意力机制
图片:
注:白色圆圈为掩码
Content流自注意力
标准的Transformer的计算过程
Query流自注意力
假设要预测单词x3
既要抛掉[Mask]标记,但又不能看到x3的输入,于是就直接忽略x3的Content,只保留位置信息,称之为Query流
预训练语言模型更多的可能性
少样本学习
目前预训练方法在下游任务上依然需要相对多的样本来取得比较好的结果,未来一个重要的研究方向是如果在更少数据的下游任务上也能取得好效果。这需要借鉴一些 few-shot learning 的思想,不仅仅对从输入到输出的映射进行建模,还要对「这个任务是什么」进行建模,这也就意味着需要在预训练的时候引入标注数据,而不仅仅是无标注数据。
怎样在 Transformer 架构基础上构建更强的长距离建模方式
怎样加强最优化的稳定性
研究者发现在训练 Transformer 时,Adam 等最优化器不是太稳定。例如目前在训练过程中,一定要加上 Warm up 机制,即学习率从 0 开始逐渐上升到想要的值,如果不加的话,Transformer 甚至都不会收敛。这表明最优化器是有一些问题的,warm up 之类的技巧可能没有解决根本问题。
怎样用更高效的架构、训练方式来提升预训练效果。
例如最近天津大学提出的 Tensorized Transformer,他们通过张量分解大大降低 Muti-head Attention 的参数量,从而提高 Transformer 的参数效率。
编码器-解码器的一体化
XLNet的作者杨植麟表示,XLNet 的另一大好处在于它相当于结合了编码器和解码器。因此理论上 XLNet 可以做一些 Seq2Seq 相关的任务,例如机器翻译和问答系统等。
首先对于 Encoder 部分,XLNet 和 BERT 是一样的,它们都在抽取数据特征并用于后续的 NLP 任务。其次对于 Decoder,因为 XLNet 直接做自回归建模,所以它对任何序列都能直接输出一个概率。这种 Decoder 的性质是 BERT 所不具有的,因为 BERT 所输出的概率具有独立性假设,会有很多偏差。
「如果我们用 XLNet 来做机器翻译,那么一种简单做法即将 Source 和 Target 语言输入到 XLNet。然后将 Target 那边的 Attention Mask 改成自回归的 Attention Mask,将 Source 那一部分的 AttentionMask 改成只能关注 Source 本身。这样我们就能完成 Seq2Seq 的任务。」
实践
https://colab.research.google.com/drive/1k3HV_SrLNzi_TS76JSe0Oy7h8eBI2HRC?hl=en#scrollTo=VRqUMaLlrJ-y