论文阅读笔记《Adversarial Feature Hallucination Networks for Few-Shot Learning》

核心思想

  本文提出一种基于数据增强的小样本学习算法(AFHN),利用生成对抗网络(GAN)实现数据集的扩充。数据增强的方法被认为可以增强类内样本方差的多样化,从而实现更加清晰地分类界限。先前的数据增强方法主要包含两类:一类是通过在基础数据集上学习一种变换映射,并将其直接应用到新的数据集上,得到映射后的合成图像用于数据扩充,这一类方法会破坏合成图像的区分能力(因为合成图像很粗糙,与原始类别并不相似);另一类方法是根据特定的任务生成对应的合成图像,这类方法保证了合成图像的区分能力,但特定的任务约束使得合成的图像容易陷入一种特定的模式,从而丧失了多样性(在GAN中这种情况称之为Mode Collapse,就是指生成的图像之间太过于相似,不具备多样性)。本文利用conditional Wasserstein Gener- ative Adversarial Networks ,cWGAN(与普通的GAN相比,cWGAN就是通过改进目标函数,进而提高训练稳定性的一个变种,此处不再详细介绍)生成样本,并通过增加分类正则项(classification regularizer)和 “反陷入”正则项(anti-collapse regularizer),解决了生成样本缺少区分能力和多样性的问题。本文提出算法的处理流程如下图
论文阅读笔记《Adversarial Feature Hallucination Networks for Few-Shot Learning》
  首先支持集图像和查询集图像经过特征提取网络FF得到对应的特征向量,支持集对应的特征向量为ss(如果有多个样本则取平均值),从[0,1]的均匀分布中采样得到两个随机变量z1,z2z_1,z_2。然后将特征向量ssz1,z2z_1,z_2输入到cWGAN的生成器GG中,得到合成的向量s~1,s~2\tilde{s}_1,\tilde{s}_2,过程如下
论文阅读笔记《Adversarial Feature Hallucination Networks for Few-Shot Learning》
将生成的s~1,s~2\tilde{s}_1,\tilde{s}_2与原始的ssz1,z2z_1,z_2输入到区分器DD中,并计算GAN损失LGAN{L}_{GAN},过程如下
论文阅读笔记《Adversarial Feature Hallucination Networks for Few-Shot Learning》
  而单纯的GAN损失并不能解决生成样本缺少区分能力和多样性的问题,因此本文又设计了两个正则化项:分类正则项(classification regularizer)和 “反陷入”正则项(anti-collapse regularizer)。其中分类正则项很好理解,首先利用softmax函数根据生成的样本s~\tilde{s}得到查询样本xqx_q对应类别的概率,计算过程如下
论文阅读笔记《Adversarial Feature Hallucination Networks for Few-Shot Learning》
式中q=F(xq)q=F(x_q),然后再利用交叉熵损失函数计算分类损失,作为分类正则项LcriL_{cr_i},该正则项的目的是为了增强生成样本的区分能力
论文阅读笔记《Adversarial Feature Hallucination Networks for Few-Shot Learning》
而“反陷入”正则项则是直接对两个合成特征向量的不相似度和产生它们的两个噪声向量的不相似度的比值进行惩罚,文字表述比较复杂,我们直接看公式
论文阅读笔记《Adversarial Feature Hallucination Networks for Few-Shot Learning》
式中,分子部分表示了两个合成特征向量之间的不相似度,而分母表示两个噪声向量之间的不相似度。有研究表明z1z_1z2z_2越相似,则s~1\tilde{s}_1s~2\tilde{s}_2越容易陷入同一种模式。当z1z_1z2z_2很相似时,也就是分母很小时,上式则相当于放大了s~1\tilde{s}_1s~2\tilde{s}_2之间的不相似度(因为要除以一个远小于1的数字)。该正则项的目的时为了增强生成样本的多样性。
  最后,将生成的样本s~\tilde{s}与原始样本ss一起输入到分类器CC中,进而实现对于查询样本xqx_q的分类。

实现过程

网络结构

  特征提取网络采用ResNet网络,生成器和区分器均采用带有Leaky ReLU**函数的两层MLP网络。

损失函数

  对于生成对抗网络部分损失函数如下
论文阅读笔记《Adversarial Feature Hallucination Networks for Few-Shot Learning》
值得注意的是“反陷入”正则项LarL_{ar}取了倒数,因此对于生成器而言是希望生成的s~1\tilde{s}_1s~2\tilde{s}_2之间的不相似度越大越好。
  对于分类器部分采用简单的分类损失函数进行训练
论文阅读笔记《Adversarial Feature Hallucination Networks for Few-Shot Learning》

训练策略

  本文的训练过程如下
论文阅读笔记《Adversarial Feature Hallucination Networks for Few-Shot Learning》

创新点

  • 本文利用cWGAN网络生成样本,用于数据集扩充,改善小样本分类效果
  • 设计了两个正则化项,提高了生成样本的区分能力和多样性

算法评价

  本文还是比较标准的采用GAN生成样本,进而实现数据增强的算法。这一类方法通常因为样本太少,导致生成的样本效果太差,而无法起到数据增强的效果。而本文通过采用稳定性更好的cWGAN算法,并设计两个正则化项,改善了生成样本的效果,使其能够应用于小样本学习算法。

如果大家对于深度学习与计算机视觉领域感兴趣,希望获得更多的知识分享与最新的论文解读,欢迎关注我的个人公众号“深视”。论文阅读笔记《Adversarial Feature Hallucination Networks for Few-Shot Learning》