【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

   【CVPR2019】SpotTune: Transfer Learning through Adaptive Fine-tuning 

论文链接:SpotTune: Transfer Learning through Adaptive Fine-tuning

一. Introduction 

使用深度学习模型时,微调(fine-tune)是应用最普遍的迁移学习方法。它具体指先在源任务上获得预训练模型,然后在目标任务上进一步训练它,从而,可以减少对目标标签数据需求的同时,提升模型的性能。

常用的微调方式有以下两种:第一个是使用目标数据集优化预训练模型中的所有参数,它的一大缺陷是,当目标数据集小且预训练网络的参数过大时,可能会产生过拟合;第二个是依据目标任务中训练集有限以及初始层学到的低级特征可以在多个任务间共享这一经验,选择微调深度网络的最后几层的参数,冻结前面其他层的参数,但是由于需要手动选择初始冻结层数,这不利于提升优化效率。并且,像ResNet这种由多个浅层网络集成的模型,初始层学到的低级特征可以共享这一前提不再适用,所以仅是微调模型的最后几层并不一定是最优的选择。

目前的方法也均是采用全局微调的策略,即,对目标任务中的所有样本采取(在某些网络层)freeze参数或者是fine-tune参数的决定。这就相当于假设该决定对整个目标数据分布是最优的,但是,现实往往并非如此。

例如,目标任务中的某些类与源任务之间的相似性较高,这些类的样本可能倾向于finetune较少的预训练参数,与之相反的样本则希望能finetune更多的预训练参数,以达最好的准确率。

所以,理想的情况是,为目标任务中的每一样本,在每一层,都制定一个该finetune还是该freeze参数的决策。

就如图1所示,上面的是在源任务上得到预训练模型,下面,在目标任务中,有两个猫的训练样本,第一个猫样本在前两块选择冻结参数,也就是保留预训练模型原有的参数,后两块做了微调,而第二个猫样本在第一三块选择了微调,在二四块选择冻结参数。而这样的选择对他们来说是达到了最优的微调策略。

【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

本文提出了一种方法SpotTune。它可以学习依赖输入(input-dependent)的微调策略,大体指从一个轻量级神经网络的输出所构成的离散分布中采样,来为每一个样本决定在哪一层该fine-tune,哪一层该freeze

由于策略函数是离散的,不可微,所以,采用了Gumbel Softmax 采样方法来训练策略网络。

在测试期间,策略网络就可以决定来自上一层的特征(feature)该进入原预训练的网络层,还是需做微调的网络层。

本文贡献:

提出了一个依赖输入的微调方法,能为每个目标样本自动决定在哪些层fine-tune

还提出了上述方法的一个global变体,即,约束所有样本在相同的k层做fine-tune,其中,这k个层可以分布于网络中的任意部分。该变体可以使最终模型有较少的参数。

通过大量实验证明,本文提出的方法在14个数据集中有12个超过了标准fine-tune方法,并且,在Visual Decathlon Challenge(10个用于多域学习算法性能测试的基准数据集),相比其他先进的方法,取得了最高的score

二. Proposed Approach

本文提出的方法能应用于不同的神经网络架构,但是由于ResNet相当于是多个较浅的层组合而成,使它对残差块之间的交换具有弹性,也就是,交换残差块对网络性能影响不大。该性质更合乎本文提出的方法。所以,接下来的实验,均是基于ResNet网络架构。下边这个图是resnet的一个基本残差块形式:

                                                                                【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

假定ResNet预训练模型的第l块的输出表示为:

                                                           【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

为了在训练期间,决定某一residual块,是否被fine-tune,先freeze了该原始块【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning,再创建一个与它并排的初始参数相同的trainable的块【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning。此时,第l层的输出可以表示为:

                                                 【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

其中,【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning是一个二进制随机变量,可以为输入图片指示该residual块是被微调还是frozen 它是从一个轻量级的策略网络输出所构成的离散分布中抽样所得,取值若为0,表示重用第【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuningfrozen块,若为1,表示通过优化【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning来微调第【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning个块。

         【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

2文中提出的SpotTune方法架构的图例说明,上面的黄色块表示策略网络,下面的两排表示预训练模型,浅棕色块的这一排表示不做微调,对应于式2中的F,深棕色表示做微调,对应于式2中的F尖,通过策略网络得到微调策略I(x)I(x)的取值就可以决定每个残差块前面的开关的开合,从而决定上一层的输出接下来该选择走微调的残差块还是冻结的残差块。

对于策略网络,它是一个轻量级的resnet网络,由于预训练ResNet模型有L个残差块,所以,它的输出logits就是一个L*2的二维矩阵,然后通过Gumbel-max采样得到微调策略I(x),它是一个L维的向量,取值不是0就是1Gumbel-max采样过程可以分为4步,分别是:

Gumbel-max的采样过程可简述如下(参考自【一文学会】Gumbel-Softmax的采样技巧

  • 对于网络输出的一个zz维向量【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning,生成zz个服从均匀分布U(0,1)的独立样本【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning,z表示类别数,由于策略网络只有两类(fine-tune or freeze参数),所以z=2;
  • 通过【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning计算得到【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning
  • 上述两步结果对应相加,得到新的向量【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning
  • 最后通过argmax取上述向量最大值的索引。

4步可以简化为式3的表示方式:

                                                      【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

由于Gumbel-Max的采样结果(不是0就是1),是离散的,不可微,也就不能用于反向传播优化网络参数,所以,在反向传播时,作者采用了Gumbel-softmax采样方式,也就是将式3中的argmaxsoftmax替换,得到式(4),其中,????τ是控制输出向量Y离散程度的参数,当它逼近于0时,生成的分布就逼近于离散分布,当它越大时,可以使生成的分布越平滑。所以,当????τ>0时,Gumbel softmax分布是平滑的,这样就可以解决了之前的无法反向传播优化网络的问题。

                                                         【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

除了特定图像的微调策略,作者还对其做了扩展,提出了它的一个全局变体,也就是,限制所有的图像在ResNet的相同的k个块作微调,这k个块可以分布在resnet的任何部位。为了实现这一变体,作者引入了两个损失函数,分别是式5和式6

                                                            【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

                                                           【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

5中的【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning是指对第【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuningresidual块,目标数据集中选择了fine-tune的图像比例,取值为01。它可以使所有训练样本趋向于选择k个微调块。式6可以迫使【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning精确到01,这样,就可以使所有图像在第l块,要么全部微调,要么都不微调,从而保证了所有图像在相同的k个块做了微调,并且这k个块可以分布于预训练模型的任何部位。最后,将这两个损失函数与分类损失函数结合,就可以得到最终的损失函数,式7该变体相比于手动选择k个块,能实现最好的准确率。并且,由于它在测试阶段不需要策略网络,且k被设为一个较小的数时,它可以减少内存占用和计算成本。

三. Experiments

数据集:

作者通过实验比较了SpotTune方法与其他微调方法和正则化方法的效果,数据集包含两部分,第一部分就是表1中列出的5个数据集,其中,前3个是细粒度分类基准,后面两个数据集较大,并且与ImageNet不匹配。第二部分用于评估来自多个域的图像的视觉识别算法的数据集,包含10个。为了减少计算负担,这10个数据集中的图片的长宽均调成72pixels。

                                    【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

                              【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

度量方式

                             【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

Baselines:将SpotTune与下面的几种微调和正则化技术做了比较实验

                                     【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

Pre-trained model(预训练模型)

                                     【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

Policy network architecture(策略网络架构)

                                     【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

SpotTune vs. Fine-tuning Baselines:

                 【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

                                               【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

表2中列出了spottune方法和其他几种微调方式在第一部分数据集上的测试结果,很明显,文中提出的SpotTune方法在几个数据集上基本上都超过了其他方法的性能,只有,在WiKiArt数据集上,比微调resnet-101略低,作者推测这是因为这个数据集的训练样本比较多,所以,作者只选取其中的25%训练样本和10%训练样本,再次比较了两种方式,结果就是上面的绿色区域,可以看出,减少了训练样本,spottune的性能胜过了微调resnet101,并且随着样本数减少,差距越来越大。其次,可以看出,只是微调后面1个或2个或3个残差块的效果均不如标准微调方式。

结果中的第一行是将预训练网络当作特征提取器,当应用于目标数据集时,它能减少参数量,但是由于域转换,致使网络性能下降。

这个正则化方法的结果非常接近于spottune的结果,但是,作者在文中提到,可以将它用于补充spottune方法,两者结合,应该能得到一个更好的结果。

Visualization of Policies:to

                                               【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

为了能更好地理解,策略网络学到的微调策略,作者将第一部分数据集上对应的每一个残差块的策略做了可视化,如图3,从下往上,每一横排代表第多少个残差块,每一列代表一种数据集,每一方块的颜色深浅代表对应数据集中,在该残差块选择了做微调的图像所占比例,占比越大,颜色越深。从图3中可以看出,不同的数据集有不同的微调策略,而SpotTune能为每个数据集,甚至每个样本自动地确认恰当的微调策略。

Visualization of Block Usage:

                                    【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

此外,作者还对测试时, 每个数据集使用的微调块的数量的分布做了研究,上面图4就是其结果,纵轴表示测试样本数量,横轴表示做了微调的残差块数量,通过不同的颜色表示几种数据集,比如,图中红色椭圆圈出的部分表示,Flowers测试集中有大概1500个样本在6个残差块中做了微调。可以从中看出,对于每一种数据集,不同的图像倾向于使用不同数量的微调残差块。再次佐证了特定图像的微调策略能比所有图像的全局微调策略的准确率更高。

                                              【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

上面的图5,展示了几张CUBSflowers数据集中,使用较少微调块的图像样本和使用较多微调块的图像样本,第一排的图都是使用微调块较少的图,可以看出,它们的背景比较干净,下面这排是使用微调块数目多的图,它们的背景相对比较复杂些。

Visual Decathlon Challenge:

           【论文笔记】SpotTune: Transfer Learning through Adaptive Fine-tuning

表3列出了spotTune和其他方法在第二部分数据集上的实验结果对比。可以看出,spottune方法基本超过了其他所有方式。与黄色标出的标准微调方式相比,spottune的参数量与它相近,但是,最终的得分3612远超过标准微调方式的得分3096,再加上第一部分数据集的实验结果,spottune14个数据中的12个上,超过了标准微调方式,只有红色线划出的两个数据集不如标准微调方式。倒数第二行的全局变体方法,在这里设定k=3,它相对于spottune方法,参数量有大幅度减少,并且分数为3401,仅次于spottune

四. 总结

提出了一种自适应微调算法,SpotTune。它是针对于目标数据集中的每一个样本的微调策略。并且在大量的数据集上验证,SpotTune的性能基本上超过了常用的几种微调方式。