Big Transfer (BiT)论文阅读笔记

这篇论文讲的是迁移学习在图像分类任务中的应用,作者强调这是一种通用型的迁移学习,也就是说这种方法不会为特定的数据集做特殊的处理,不同等级的预训练模型在往其他数据集上迁移时均采用相同的处理的方法,以此来证明BiT这种方法的普适性。

方法

上游预训练模型

上游预训练的模型规模体现在训练数据的大小,而不是模型的大小。作者试验了几种模型,默认采用ResNet152x4。论文中按照数据的大小,分别训练的BiT-S、BiT-M和BiT-L三种规模的预训练模型,分别对应的数据集是ILSVRC-2012(1.3M)、ImageNet-21K(14M)和JFT(300M)。

预训练模型采用了Group Normalization和Weight Standardization。作者给出了不使用Batch Normalization的两个理由:1. 训练模型会使用分布式,使用BN的话无法利用到大Batch Size的优势,因为不同卡之间没有同步;2. BN需要更新运行中的数据,不适合用于任务迁移。

迁移到下游任务

训练好上游网络后,需要把网络fine-tune到下游任务中,论文采用了一种叫做BiT-HyperRule的启发式方法去选择和调整几个重要的训练超参——训练周期长度、数据分辨率和是否使用MixUp数据增强。这种方法是通用的,只会在某些不同情况下做一下区分,调整的方法如下:

  1. 数据分辨率。小于 96 × 96 96\times96 96×96的分辨率,先将图片resize到 160 × 160 160\times160 160×160,再随机剪切出 128 × 128 128\times128 128×128的方框;对于大于 96 × 96 96\times96 96×96的分辨率,先将图片resize到 512 × 512 512\times512 512×512,再随机剪切出 480 × 480 480\times480 480×480的方框。
  2. 训练周期长度。对于小任务数据集(小于20K),训练500个step;对于中等数据集(大于20K小于500K),训练10K个step;对于大数据集(大于500K),训练 20K个step。在训练总步长的30%、60%和90%时分别降到原先的0.1。
  3. MixUp数据增强。在中型和大型数据集上使用,小型数据集上不使用。其中, α = 0.1 \alpha = 0.1 α=0.1

在训练时,数据的预处理包括resize到正方形,随机剪出一个小一些的正方形,再随机水平翻转;在测试时只需将图片resize到一个固定的大小。在训练方式上,数据的输入分辨率和Fine-tune借鉴了FixRes的训练方法。

另外,作者发现,在下游任务fine-tune时,并不需要使用其他正则方法,包括优化器的weight decay或者DropOut。

实验

在上游预训练模型上,对于任何的数据集,论文均使用ResNet-152架构,其中基本Block的中间层channel会放大4倍。论文使用统一的训练超参:带momentum的SGD,初始学习率为0.03,momentum为0.9;输入数据的大小为 224 × 224 224\times224 224×224,BiT-S和BiT-M训练了90个epoch,在第30、60和80代时将学习率除以10,BiT-L训练了40个epoch,在第10、23、30和37代时将学习率除以10。Batch size是4096,在谷歌云上使用了512个TPUv3,每张卡8张图片。训练开始阶段还使用了warm-up的方法在前5000个step上逐渐将学习率线性上升到初始值(0.03 * batch_size / 256)。预训练时,优化器的weight decay设为0.0001。

在下游任务迁移的Fine-tune中不使用weight decay,仍然采用SGD,初始值设为0.003,momentum为0.9,batch size为512。其他的超参设置参考上一个部分的调整策略。

在实验上,这里只贴出BiT-L模型迁移到一些常见的图像分类数据上的结果。如表格1所示。其他的实验和一些分析请读者自行阅读论文的实验部分。

表格1. BiT-L模型在一些分类数据集上Top-1准确率
Big Transfer (BiT)论文阅读笔记

从表1中可看出,在JFT-300M数据集上预训练的BiT-L模型迁移到其他数据集时,相比之前通用方法的SOTA,均有不小的准确率提升,证明了BiT方法在多种任务上的普适性。但是相比于专门方法的SOTA,准确率还差一些。这里的通用方法指的是模型和算法采用了与数据任务无关的相同方法(BiT也是),专门方法指的是在某个数据集上训练时,有条件地使用了只适合该数据集的特殊处理方法(对其他数据集并不适用)。