ResNet的skip connection 残差网络的学习

今天我们介绍一下残差网络,学习完前面的知识我们会知道网络的深度对于模型的性能是至关重要的,所以理论上而言模型更深可以取得更好的结果。但是事实真的是这样的吗?实现发现深度网络会出现退化的问题:网络深度增加时,网络准确度出现饱和,甚至会出现下降。这不是过拟合问题,而是由于深层网络中存在梯度消失和梯度爆炸的问题,这使得深度学习模型很难训练, ResNet的skip connection就是为了解决梯度消失这个问题,skip connection则能在后传过程中更好地把梯度传到更浅的层次中。梯度消失问题:在反向传播的时候,随着传播深度的增加,梯度的幅度会急剧减小(试想一下,在链式求导法则中,小于1的数连续相乘,不是会变得更小吗?如果层数很多,一下子就会接近0)残差网络的定义是这样的:当输入是x时,学习到的特征标记为H(x),现在我们希望其可以学习到残差F(x) = H(x)-x,这样原来的学习特征就成为了F(x)+x,当残差为0时,此时堆积层就是做了一个恒等映射,即H(x) = x,这样至少保证到我们的网络性能不会因为反向传播的梯度问题而退化。实际上残差是不会为0的,这样保证堆积层在输入特征基础上学习到新的特征从而拥有更好的性能。下面是残差网络的示意图:
ResNet的skip connection 残差网络的学习

为什么残差学习相对更容易,我们可以从数学的角度去分析这个问题,首先我们定义残差单元:
ResNet的skip connection 残差网络的学习

Xl和Xl+1表示的是第l个残差单元的输入和输出,F是残差结构,表示学习到的残差,当h(xl)=xl时表示的就是恒等映射,f是relu**函数,我们从浅层l到深层L的学习特征为:
ResNet的skip connection 残差网络的学习

反向传播过程为:

ResNet的skip connection 残差网络的学习
可以看到,对于任何一层的x的梯度由两部分组成,其中一部分直接就由L层不加任何衰减和改变的直接传导l层,这保证了梯度传播的有效性;另一部分也由链式法则的累乘变为了累加,这样有更好的稳定性,括号中的1可以保证该式子的值不会小于1,在梯度的链式求导中解决了梯度弥散的问题。下面我们再分析一下残差单元,ResNet使用两种残差单元,如下图所示。左图对应的是浅层网络,而右图对应的是深层网络。对于短路连接,当输入和输出维度一致时,可以直接将输入加到输出上。但是当维度不一致时(对应的是维度增加一倍),这就不能直接相加。有两种策略:(1)采用zero-padding增加维度,此时一般要先做一个downsamp,可以采用strde=2的pooling,这样不会增加参数;(2)采用新的映射(projection shortcut),一般采用1x1的卷积,这样会增加参数,也会增加计算量:
ResNet的skip connection 残差网络的学习

总结:今天我们对残差网络进行了一个了解,用数学知识展示了其原理,残差网络在现在的深度学习中应用非常广,Resnet可以说是解决了深度网络的瓶颈。在下一节的内容中,我们将使用tensorflow实现一个resnet网络

创作不易,点赞支持一下呗
ResNet的skip connection 残差网络的学习