Unsupervised Data Augmentation for Consistency Training

论文地址:https://arxiv.org/pdf/1904.12848v4.pdf

官方实现代码(tensorflow):https://github.com/google-research/uda

同样来自于谷歌的,偶然看到的,关于无监督数据增强方法

文章主要在三种任务上进行了相关实验:文本分类、图像分类、迁移学习

创新点:

  • 监督学习中的数据增强方法在半监督学习中同样可以用来对无标签数据进行数据中增强
  • 主要提出了一个名为Unsupervised Data Augmentation (UDA)无监督数据增强方法。

数据增强一直在监督学习中起着锦上添花的作用,因为到目前为止数据增强通常是用在数据集相对比较小的标记数据集上,以达到扩充数据集的多样性的作用,但是数据增强起到的作用依然是受限的。基于此,我们在一致性训练(即原始输入图片和添加噪声色图片,对模型的输出没有影响,输出是一致的)框架下,把这些监督学习中优秀的数据增强方法扩展到半监督学习任务当中。

当前半监督学习中,利用无标签数据去进一步平滑模型的方法,主要归纳为以下两步

Unsupervised Data Augmentation for Consistency Training

先给一个输入x,然后输出分布Unsupervised Data Augmentation for Consistency Training,再给一个添加了噪声的x,输出分布为Unsupervised Data Augmentation for Consistency TrainingUnsupervised Data Augmentation for Consistency Training,最后最小化以上两个分布的距离。这个过程有两点好处:1.会让模型对抗噪声的能力得到提高,当输入发生改变的时候,输出不会发生大的变化,会比较平滑,2.可以把标签信息从标签数据传递无标签数据中。

Unsupervised Data Augmentation for Consistency Training

上图呈现了对于UDA方法训练目标框架,其中Unsupervised Data Augmentation for Consistency Training是经过数据增强方法(RandAugment)增强过的无标签数据.M为当前的训练的预测模型。总损失=标签数据的交叉熵损失+无标签数据的一致性损失,总损失公式如下:

Unsupervised Data Augmentation for Consistency Training

其中Unsupervised Data Augmentation for Consistency Training的设置是为了平衡监督损失和无监督损失,在该文章中,该参数设置为1.Unsupervised Data Augmentation for Consistency Training中的参数是从当前模型中复制过来的,并且这一部分不进行反向梯度更新。总的来说,前半部损失是为了分类,而后半部分损失则是为了提高系统的鲁棒性。对此,因为前半部分的标签数据比较少,而后面的无标签数据比较多,所以前半部分必定会随着训练的增加,发生过拟合。为了防止这种过拟合,文章提出了一种Training Signal Annealing (TSA)的方法,该方法仅仅只针对标记数据。通过动态改变阈值来防止过拟合。其具体操作过程如下:

Unsupervised Data Augmentation for Consistency Training

其中K表示训练数据类别数,T为总的训练步数,t为当前训练步数,当预测值Unsupervised Data Augmentation for Consistency Training高于阈值Unsupervised Data Augmentation for Consistency Training的时候,就从总损失中移除该样本的损失。三个图分别代表三种不同的调节方法:对数、线性、指数。

三种函数的适用条件:

  • 当模型容易过拟合时,即模型会在很短的时间内对样本做出高概率的预测,这时我们就期望阈值的增长更慢一下,这样可以删掉更多容易训练的样本,因此可以采取 exp指数函数
  • 当模型很难过拟合,即模型会花费较长时间才能对样本做出高概率的预测,这样相同时间内,模型能够做出高概率预测的样本就比较少,此时需要删掉的样本也比较少,因此我们期望阈值在短时间内会比较大,这样删掉的样本就比较少,因此可以采取 log 对数函数
  • 对于一般的样本,直接采用均匀增长的线性函数就可以。


 

实验结果:

Unsupervised Data Augmentation for Consistency Training

结论:在监督学习中性能好的增强算法,在半监督学习中,同样适用。

Unsupervised Data Augmentation for Consistency Training

 

  • VAT: 采用高斯噪声;
  • MixMatch除了使用标准的crop, flip之类的增强外,还使用了MixUp增强方法
  • UDA 采用的数据增强方法为 Random Augment,这种增强方法是正态地从PIL图片转换库提取,这样就比以上的方法具有更多的组合
  • 三条曲线都呈下降趋势,即标签样本越多,分类效果越好
  • 在相同的标签样本情况下,UDA的分类效果最好

在Imagnet上的结果:

Unsupervised Data Augmentation for Consistency Training

TSA评估:

Unsupervised Data Augmentation for Consistency Training

  •