MegDet:大mini-batch 检测器

MegDetface++ 提出的通用检测框架。整体结构为类似于faster RCNN2层结构,基础网络类似于Resnet50

获得了2017 COCO检测第一名,发表于cvpr 2018

文章主要讲解大mini-batch这个训练技巧,实现了在大mini-batch下的精度提升1.5个点。

MegDet:大mini-batch 检测器

mini-batch的缺点:

(1)训练时间太长。

(2)小的batch对于batch normalization 层的参数(mean,var)计算很不利。(mean,var)肯定是batch越大,计算的误差越小,越接近整体数据的(mean,var)。Group Norm那篇文章中实验的batch norm的最小值是16比较OK,小于16batchsize会使得(mean,var)不准确。

(3)对于检测框架中,小的batch中会存在正负样本严重不平衡的现象。从下图中可以看出,256-batch会比16-batch具备更大的正负样本比例。

MegDet:大mini-batch 检测器

下图(a)(b)为正负样本比例较少的情形,(c)(d)为正负样本比例较大的情况。可以看出,当正负样本比例较大时,正样本周围的绿色框会更多更集中,更有利于对目标位置的回归。

MegDet:大mini-batch 检测器

mini-batch需要大的学习速率:

文章基于2方面做了解释。

(1)基于Linear Scaling Rule 原则,对于进行多机器训练的大的mini-batch,假设batch size增加了k倍,准确的说应该是每个机器吃的batch size不变,增加了k倍的显卡,那么学习速率也应该增加为原来的k倍,即learning_rate_hat=k*learning_rate

(2)基于方差等价的原则。

原始batch size的方差计算如下:

MegDet:大mini-batch 检测器

batch size 扩大为k 倍时,方差如下:

MegDet:大mini-batch 检测器

可以看出,当batch size 扩大为k 倍时,方差减少为原来的1/k。这时为了想要得到和原来一样的方差,就得将现在的方差乘以k倍,但是现在的方差已经是确定的了。怎么办呢?

这时可以将学习率和现有的反差一起计算方差。

MegDet:大mini-batch 检测器

如上面式子,要想上下相等,只能是r_hat=k*r。即将学习率提高为原来的k倍,即learning_rate_hat=k*learning_rate

写到这里还有一个问题,为什么可以将learning_rate*方差一起计算他们的方差。

不带正则化的梯度跟新的公式是weight_decay=weight_decay-learning*step_weight

 

热启动策略(Warmup Strategy ):

MegDet:大mini-batch 检测器

不是简单的直接将学习速率提高k倍,而是基于线性增长的原则(linear scaling

rule ),将学习速率一点一点的增大。直到达到k倍学习速率。

GPUBatchNorm

MegDet:大mini-batch 检测器

从上图可以看出,batch norm梯度跟新的策略为主从式的跟新。就是说,每个worker将计算mean,var的参数都传递到一台主机,主机计算完mean,var再分发给每个worker。感觉不像Horovod这种 ring-all-reduce的环状跟新策略好。

 

References:

MegDet: A Large Mini-Batch Object Detector