批标准化(BatchNorm)

注:本文部分参考自以下文章:
深入理解Batch Normalization批标准化
李理:卷及神经网络之Batch Normalization的原理及实现

原文链接:《Batch Normalizaion: Accelerating Deep Network Training by Reducing Internal Convariate Shift》
翻译、导读等推荐:12

1. BN目的

机器学习领域有个很重要的假设:独立同分布(IID,Independent Identically Distributed)假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。那BatchNorm的作用是什么呢?BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。

2. 内部协变量漂移(Internal Covariate Shift)

When the input distribution to a learning system changes, it is said to experience covariate shift
covariate shift问题是由于训练数据的领域模型 Ps(X) 和测试数据的 Pt(X) 分布不一致造成的,这里的下标s和t是source和target的缩写,代表训练和测试。

Mini-Batch SGD vs SGD(one sample):梯度更新方向准确、并行计算速度快,但需要调节很多超参数(学习率、初值等)。
各层权重参数严重影响每层的输入,输入的小变动随着层数加深不断放大。这就导致,各层输入分布的变动导致模型需要不停地去拟合新的分布。
于是,BN希望通过每层的输入均值、方差进行规范化,使输入分布一致

3. BN的思想

对于每个隐层神经元,把逐渐向非线性函数映射后向取值区间极限饱和区靠拢的输入分布强制拉回到均值为0方差为1的比较标准的正态分布,使得非线性变换函数的输入值落入对输入比较敏感的区域,以此避免梯度消失问题。

如果我们能保证每次minibatch时每个层的输入数据都是均值0方差1,那么就可以解决这个问题。因此我们可以加一个batch normalization层对这个minibatch的数据进行处理。但是这样也带来一个问题,把某个层的输出限制在均值为0方差为1的分布会使得网络的表达能力变弱。因此作者又给batch normalization层进行一些限制的放松,给它增加两个可学习的参数 β 和 γ ,对数据进行缩放和平移,平移参数 β 和缩放参数 γ 是学习出来的。

备注:
由于sigmoid这类**函数,只有在0左右的邻域处,导数较大;所以,BN策略在一定程度上可以保证**函数的梯度一直较大,这避免了梯度消失问题;并且梯度够大表明训练速度较快。当然,由于这使得x大多落在sigmoid的线性区,而违背了当初使用sigmoid非线性变换的初衷,从而降低了表达能力,上述的参数β和γ在一定程度上可以解决此问题。

批标准化(BatchNorm)

4. BN 的预测(Inference)

虽然训练过程可以根据 Mini-Batch来获得统计量,但是预测时只有一个数据,无从计算合理的均值和方差。解决办法是,使用训练的所有数据(population)的均值和方差(用每个mini-batch的统计量计算得来即可)。
有了均值和方差,每个隐含层也有训练好的β和γ,就可以在预测过程中进行BN操作了

5. BN 的优势

论文中将Batch Normalization的作用说得突破天际,好似一下解决了所有问题,下面就来一一列举一下:
  
(1) 可以使用更高的学习率。如果每层的scale不一致,实际上每层需要的学习率是不一样的,同一层不同维度的scale往往也需要不同大小的学习率,通常需要使用最小的那个学习率才能保证损失函数有效下降,Batch Normalization将每层、每维的scale保持一致,那么我们就可以直接使用较高的学习率进行优化。
(2) 移除或使用较低的dropout。 dropout是常用的防止overfitting的方法,而导致overfit的位置往往在数据边界处,如果初始化权重就已经落在数据内部,overfit现象就可以得到一定的缓解。论文中最后的模型分别使用10%、5%和0%的dropout训练模型,与之前的40%-50%相比,可以大大提高训练速度。
(3) 降低L2权重衰减系数。 还是一样的问题,边界处的局部最优往往有几维的权重(斜率)较大,使用L2衰减可以缓解这一问题,现在用了Batch Normalization,就可以把这个值降低了,论文中降低为原来的5倍。
(4) 取消Local Response Normalization层。 由于使用了一种Normalization,再使用LRN就显得没那么必要了。而且LRN实际上也没那么work。
(5) 减少图像扭曲的使用。 由于现在训练epoch数降低,所以要对输入数据少做一些扭曲,让神经网络多看看真实的数据。

推荐阅读:深入解读Inception V2之Batch Normalization