机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization

本篇博客中我将分享一篇EMNLP2019与文本问答系统鲁棒性相关的论文:《QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization》。
论文下载连接
开源代码(添加注释;提取码:nq7f)

摘要

标准的准确率指标表明现在的阅读理解系统在很多数据集上都取得了很好的表现。然而,现有系统真正理解语言的程度仍是未知的,不能很好的区分干扰性句子/答案(和问题很像,单词重合度很高,但不能真正回答该问题的句子)。本文提出QAInfomax通过最大化段落(上下文)、问题及其对应答案之间的互信息,来对阅读理解系统进行正则化(形式上类似于机器学习中的L2正则化,详见下文)。QAInfomax能使模型不只是简单的学习问题和答案表面的关联。QAInfomax在 Adversarial-SQuAD数据集上取得了state-of-the-art的性能。

介绍

现有阅读理解模型仅仅利用数据表面的关联,类似一种模式上的匹配,并未真正理解语义。为了解决这个问题,研究人员提出了SQuAD数据集的对抗性版本,给每个段落添加一个干扰性句子(答案),来测试模型的鲁棒性。结果表明模型无法区分问题的实际答案和与问题有公共单词的干扰答案(过稳定问题),几乎所有SOTA的阅读理解模型在对抗性样本上性能都显著降低。

本文提出QAInfomax来最大化段落、问题及其对应答案之间的互信息,让模型在学习期间不会陷入数据的表面偏差。QAInfomax把deep infomax模型(其在图像、语音领域被证明有高效的表示学习能力)拓展到文本领域。来提高QA模型的鲁棒性,使其对干扰句子更加敏感,缓解过稳定问题。使QA系统生成的答案携带信息,该信息不仅能够解释问题,也能够解释答案自身。

互信息估计

两个随机变量X和Y之间的互信息定义如下:
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization
联合分布p(X,Y)和边缘分布p(X),p(Y)乘积的KL散度。

MINE(互信息神经估计)通过训练一个分类器来区分来自联合分布的正样本(x,y)和来自边缘分布乘积的负样本(x,yˉ)(x,\bar{y}).MINE使用Donsker-Varadhan表示(DV)作为下界来估计互信息:
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization
EP,ENE_P,E_N分别表示正,负样本的期望,g是一个判别函数,可以使用神经网络结构,输出一个实数值。

然而DV是一个很强/准确的下界,我们只需要最大化互信息(最大化下界),不需要得到精确的下界值。因此,Deep Info(DIM)使用Jensen-Shannon(JS)散度替代DV,它可以用2分类交叉熵损失高效实现:
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization

本文修改上式,通过交换x,y的角色:
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization
其中(xˉ,y)(\bar{x},y)也是从边缘分布乘积采样的负样本。该式提供了更好的性能,更多关于互信息参数化的探索将是未来工作的重点。

方法

抽取式问答数据集如SQuAD,问题Q={q1,...,qK}Q=\{q_1,...,q_K\}的答案A=a1,...,aMA={a_1,...,a_M}是从段落P={p1,...,pN}P=\{p_1,...,p_N\}从抽取的一个范围{pm,...,pm+M}\{p_m,...,p_{m+M}\}。给定Q,P,其通过模型编码后的表示如下:
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization

现有的阅读理解模型,一般把段落的表示rpr^p通过一个单层神经网络,产生段落中每个token作为开始和结束位置的概率分布,然后和真实的开始位置和结束位置(one-hot向量),分别对应计算交叉熵损失Lstart,LendL_{start},L_{end},得到Lspan=Lstart+LendL_{span}=L_{start}+L_{end}. 通过最小化LspanL_{span}来优化模型。

QAInfomax实际上又增加了一个LinfoL_{info}损失,来对模型进行正则化,使模型不仅仅利用表明偏差,最大化段落、问答和答案之间的互信息。包含局部限制(LC)和全局限制(GC)。
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization
最大化互信息,需要来自联合分布的正样本和来自边缘分布乘积的负样本。

LC:最大化答案范围内某个token和该范围内其他token以及该范围周围几个token之间的互信息。 定义一个正样本(x,y):x是某一个样本答案范围内的一个token的表示 xra={rmp,...,rm+Mp}x \in r^a=\{r^p_m,...,r^p_{m+M}\},y是该样本答案范围内以及该范围周围C个token中除去x的某个token的表示y=ricrc={rmCp,...,rm+M+Cp}{x}y=r^c_i \in r^c=\{r^p_{m-C},...,r^p_{m+M+C}\} - \{x\},C表示考虑的上下文token数,是一个超参数。定义负样本(xˉ,y)(\bar{x},y)(x,yˉ)(x,\bar{y}),其中xˉ\bar{x}是随机采样的另一个样本答案范围内的一个token的表示xˉrˉa={rˉlp,...,rˉl+Lp}\bar{x} \in \bar{r}^a=\{\bar{r}^p_l,...,\bar{r}^p_{l+L}\},yˉ\bar{y}是随机采样的另一个样本答案范围内以及该范围周围C个token中除去xˉ\bar{x}的某个token的表示yˉ=rˉjcrˉc={rˉlCp,...,rˉl+L+Cp}{xˉ}\bar{y}= \bar{r}^c_j \in \bar{r}^c = \{\bar{r}^p_{l-C},...,\bar{r}^p_{l+L+C}\} - \{\bar{x}\}。LC计算如下:
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization

GC: 最大化答案范围内所有token的摘要 s 和段落以及问题中所有token(除去答案中的token)之间的互信息。定义一个正样本(x,y):x是某一个样本答案范围内所有token表示的摘要向量 x=s=S(ra)=σ(1Mria)x=s=S(r^a)=\sigma(\frac{1}{M}\sum r^a_i), σ\sigma是logistic sigmoid nonlinearity。y是某一个样本段落以及问题的所有token中某个token的表示(除去答案中的token),y=rir={rq,rp}{ra}y=r_i\in r=\{r^q,r^p\}-\{r^a\}.定义负样本(xˉ,y)(\bar{x},y)(x,yˉ)(x,\bar{y}),其中xˉ\bar{x}是随机采样的另一个样本答案范围内所有token表示的摘要向量xˉ=sˉ=S(rˉa)\bar{x}=\bar{s}=S(\bar{r}^a),yˉ\bar{y}是随机采样的另一个样本段落以及问题的所有token中某个token的表示(除去答案中的token),yˉ=rˉjrˉ={rˉq,rˉp}{rˉa}\bar{y}=\bar{r}_j\in \bar{r}=\{\bar{r}^q,\bar{r}^p\}-\{\bar{r}^a\}.GC计算如下:
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization

LinfoL_{info}计算如下:
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization

模型总损失计算如下:
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization

实验

C,α,β,γC,\alpha,\beta,\gamma分别为5,1,0.5,0.3。模型采用bert-base-uncased。判别函数g使用bilinear function:
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization

模型在原始SQuAD1.1数据集上训练,在SQuAD的对抗性版本数据集上测试其鲁棒性。

使用的评估指标是 ADDSENT 和 ADDONESENT:
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization
模型结果:
机器阅读理解 | (4) QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization

具体细节可以参见原文。