深度学习建模训练总结(八):如何处理梯度消失(爆炸)

在讨论如何处理梯度消失梯度爆炸的问题之前,先来看看梯度消失梯度爆炸的成因。

梯度消失和梯度爆炸本质上是一样的,都是反向传播算法造成的。我们知道,一个神经网络就是一个一层层嵌套的非线性函数,假设现在模型一共有四层,我们求第二层的一个参数更新的梯度:

δ w 2 = ∂ L o s s ∂ w 2 = ∂ L o s s ∂ w 2 ∂ f 4 ∂ f 3 ∂ f 3 ∂ f 2 ∂ f 2 ∂ x 2 \delta w_2 = \frac{\partial Loss}{\partial w_2} = \frac{\partial Loss}{\partial w_2} \frac{\partial f_4}{\partial f_3} \frac{\partial f_3}{\partial f_2} \frac{\partial f_2}{\partial x_2} δw2=w2Loss=w2Lossf3f4f2f3x2f2

其中,f4关于f3的偏导,就是对**函数求导,当该部分大于1时,同时层数越多,对于接近输入的层,求出的梯度更新就会以指数形式增加,反之当其小于1,就会指数衰减,造成梯度爆炸和梯度消失。具体表现上,当更新的权重是一个十分大的数值,参数的剧烈变化会导致误差一下子上升,甚至变为无穷,而梯度消失,可能相对难察觉,一个可能的现象就是loss几乎保持不变,因为这时候接近输入层的节点已经几乎不再更新了。

深度学习建模训练总结(八):如何处理梯度消失(爆炸)
可以看看上图,越接近输入层,梯度更新的速度就越慢,这是反向传播造成的,这里仅仅只有四层,可以预见如果模型有十几层,那么接近输入的隐层将十分难进行学习。

除此之外,**函数的选择也会导致梯度消失的问题,还是因为链式求导会把每层求导结果相乘,当每层的求导小于1,就会导致相乘之后越来越接近0,而在**函数中,sigmoid函数的求导恰恰一定小于0.25,这就导致了如果模型使用了过多的sigmoid,就会导致梯度消失的出现,除此之外,tanh也是类似的原因,因为它的求导后函数最大值为1,所以也可能导致梯度消失,只是相对sigmoid更缓和。

接下来就来看看有哪些措施可以缓解梯度消失(爆炸)的问题。

第一个是梯度剪切,主要针对梯度爆炸,思想是设置一个梯度的剪切阈值,梯度更新的时候,如果梯度超过这个值,就把其强行限制在这个范围内,防止梯度爆炸。

第二个是权重正则化,也是针对梯度爆炸,假设我们现在有MSE损失函数:

L o s s = ( y − W T x ) 2 Loss = (y - W^T x)^2 Loss=(yWTx)2

权重正则化就是对损失函数增加一个惩罚项:

L o s s = ( y − W T x ) 2 + a ∣ ∣ W ∣ ∣ 2 Loss = (y - W^T x)^2 + a||W||^2 Loss=(yWTx)2+aW2

这时候,对于模型来说,梯度下降的目的就变成了既要减小loss,也要减小权重,也就避免了梯度爆炸导致更新的梯度过大,导致更新后的权重变得过大。

第三个是关于**函数的选择,上面也提到,**函数的选择对于梯度爆炸消失有很大的影响,这里介绍两个可以缓和梯度消失的**函数,第一个是ReLu:

y = m a x ( 0 , x ) y = max(0, x) y=max(0,x)
y ′ = 1 ( x > 0 ) y' = 1 \quad (x>0) y=1(x>0)
y ′ = 0 ( x < = 0 ) y' = 0 \quad (x<=0) y=0(x<=0)

可以看出,当x大于0的时候,导数恒等于1,所以这时候就不会存在梯度消失或者梯度爆炸的问题,顺带一提,因为ReLu是线性函数,所以求导速度比sigmoid、tanh都要快很多,所以在模型中使用relu可以让训练速度加快。除此之外,注意到当x<0时ReLu求导等于0,这就直接导致更新的梯度为0,就会导致一些神经元无法**,虽然可以通过设置小一点的学习率改善这个问题,但也因此不能完全说ReLu就解决了梯度消失问题,但是对于更新的梯度为0这一点,对于研究网络压缩、神经元剪切而言又是有价值的。总的来说,Relu目前是使用比较广泛的一个**函数。

针对x小于0会导致Relu梯度为0的问题,又提出了leakyRelu,主要就是针对x小于0的情况做一个改善:

y = m a x ( k x , x ) y = max(kx, x) y=max(kx,x)

k是leaky系数,小于1大于0,当x小于0的时候,y=kx,梯度也等于k。

第四个介绍的是batch normalization,目前来说BN是最常用的用于解决梯度消失、梯度爆炸的方法,最简单的BN就是计算出batch样本的均值方差,然后对数据进行标准化处理,把数据的分布转化为正态分布。

BN的目的主要是考虑到一个问题,目前神经网络进行训练,经过非线性变换,不同层的数据分布是不同的,而且每轮经过模型反向传播后,分布又改变,也就是所谓的internal covariate shift,在分布十分复杂且不断变化的情况下,模型就很难学习,所以就有人提出,通过batch normalization的方法,把每一层(或者指定几层)的数据转换为正态分布,这样模型就更容易收敛了。

以上是BN层的一个优点,而对于梯度消失问题,他主要的作用是在**函数方面,举个例子,假设我们使用sigmoid**函数,

深度学习建模训练总结(八):如何处理梯度消失(爆炸)

可以看到,对sigmoid而言,如果输入的数据小于-4或者大于4,那么对应的导数就接近0,而通过BN层,数据分布转化为均值为0的正态分布,计算的梯度就能大概率远离0,避免造成梯度消失。

但是,单纯的标准化处理是不够的,有时候我们通过深层神经网络就是希望学习一个复杂的分布,模型十分艰难学习了这个分布,然后你又加一个标准化处理,把这个分布又变回正态分布,岂不是之前的功夫都是白费的,所以就有人进一步提出,需要在标准化处理之后加上shift操作:

y = s c a l e ∗ x + s h i f t y = scale * x + shift y=scalex+shift

其中的参数可以通过模型学习,当然本质上这也终究是一个平衡,一方面我们希望模型能学习到一个复杂的分布,另一方面为了训练过程中避免梯度消失、收敛速度更快,我们又不希望模型的分布过于复杂,所以只能折中一下。

最后一个要介绍的方法是残差连接,残差连接主要可解决梯度消失的问题。

深度学习建模训练总结(八):如何处理梯度消失(爆炸)

x l + 1 = x l + F ( x l ) x_{l+1} = x_l + F(x_l) xl+1=xl+F(xl)

可以看到,输出除了包含输入的非线性变换之外,也加入了原始输入,两者做按位相加,在这里,有可能输入x做了非线性变换之后,维度发生了变化,这种变化可能是空间上,也可能是深度上,对于空间的不同,如果输出F(x)小于输入x,则对输出F(x)补0,否则对输入x做一个线性变换实现空间上的扩张:

x l + 1 = W ∗ x l + F ( x l ) x_{l+1} = W*x_l + F(x_l) xl+1=Wxl+F(xl)

对于深度上的不同,只需要通过1*1的卷积层进行升维即可。

上面用x表示是为了和图片的公式相对应,但看起来总感觉像是输入的x,为了避免误会,下面用Si表示第i层的节点,Fi表示对第i层的非线性变换,假设我们的网络由多个残差连接构成

S L = S L − 1 + F L − 1 ( S L − 1 ) = S L − 2 + F L − 2 ( S L − 2 ) + F L − 1 ( S L − 1 ) = S l + ∑ i = l L − 1 F i ( S i ) S_{L} = S_{L-1} + F_{L-1}(S_{L-1}) = S_{L-2} + F_{L-2}(S_{L-2}) + F_{L-1}(S_{L-1}) = S_{l} + \sum _{i=l} ^{L-1} F_i(S_i) SL=SL1+FL1(SL1)=SL2+FL2(SL2)+FL1(SL1)=Sl+i=lL1Fi(Si)

上式表明,假设多个神经层都有残差连接,那么浅层的输入就会一直传递到深层,而信息没有经过非线性变换,浅层在求导的时候就多了一项非指数增长(衰减)的Sl,避免了梯度消失。

举个例子,我们可以考虑一个网络完全由残差连接构成,在反向传播的时候,假设我们希望计算损失关于第一层的一个参数w11的偏导:

∂ L O S S ∂ w 11 = ∂ S 1 ∂ w 11 + ∂ ∑ i = 1 L − 1 F i ( S i ) ∂ w 11 \frac{\partial LOSS}{ \partial w_{11}} = \frac{\partial S_1}{ \partial w_{11}} + \frac{\partial \sum _{i=1} ^{L-1} F_i(S_i)}{ \partial w_{11}} w11LOSS=w11S1+w11i=1L1Fi(Si)

通过残差连接,第一层的信息可以完整传递到输出,第一项就是计算第一层神经层节点关于第一层参数w11的偏导,这样就避免了深层网络求偏导出现梯度消失的问题,这就是残差网络的一个重要作用。

一些文章提到残差连接是一个恒等映射,个人的理解是,一般为了残差连接中输出可以和输入按位相加,非线性变换中会注意保持形状不变,最后就可以直接和原始输入相加,对于变换后输入输出尺寸保持一致的现象,即称为对输入的恒等变换。