Pytorch: ResNet论文学习解析网络结构并pytorch实现

正好课程作业需要用到迁移学习,就顺便学习了很厉害的ResNet网络,是真的厉害呀。

首先给出最具有权威性的论文原文。
论文地址:Deep Residual Learning for Image Recognition

1、ResNet的亮点

1.1、现在的网络层数越来越多,很有可能出现梯度消失和梯度爆炸的问题-----resnet利用了BN(Batch Normalization)方式
1.2、如果只采用扁平的plain结构的网络,层数过高反而精度越低这类的退化问题----resnet采用了residual(残差结构)的网络结构
1.3、模型层数越高,参数越多,计算复杂度越大,优化越困难-----采用residual网络结构能够在不增加参数的情况下更好拟合

2、ResNet网络结构解析

先放出ResNet的网络结构(论文中给出的是34layers):
Pytorch: ResNet论文学习解析网络结构并pytorch实现
我们先来看看为什么要使用残差网络,在传统的网络模型中,都是靠着不同的层数堆叠,但这种方式当层数很深时,会增大训练误差,因此会造成不好优化,论文中给出了传统模型在两个数据集上的结果。
Pytorch: ResNet论文学习解析网络结构并pytorch实现
看到56-layer的网络的误差比20-ayer的误差还要大,所以当层数比较深时,plain网络并不能很好拟合数据,因此提出了新型的残差网络。作者又做了实验来验证网络能够解决退化问题。如下图:
Pytorch: ResNet论文学习解析网络结构并pytorch实现
左侧是扁平网络,34层的错误率比18层的要高,右侧是残差网络,34层的错误率比18层的低了,所以深层的残差网络表现更好。


残差网络的组成结构分别是residual representations(相当于主干short connection(相当于捷径 组成,我们来看看核心组成成分:
Pytorch: ResNet论文学习解析网络结构并pytorch实现
在残差网络中,不是让网络直接拟合原先的映射,而是拟合残差映射。 假设原始的映射为H(x)H(x),残差网络拟合的映射为:F(x)=H(x)xF(x):=H(x)-x。其中F(x)F(x)去拟合残差会比直接利用输入去拟合输出拟合能力更好,并且运算也会比较简单。论文中给出的结论是:

  1. deep residual nets are easy to optimize 【更容易优化】
  2. deep residual nets can easily enjoy accuracy gains from greatly increased depth【能够在深层网络得到更好地精度】

接下来再来看看具体的残差结构,下图是论文给出的两种残差结构:
Pytorch: ResNet论文学习解析网络结构并pytorch实现
为什么深层和浅层用的不一样呢,如果是深层网络,就要尽可能减少参数量。如果同样假设上面左右两幅图输入输出的维度均为256,那么左侧的参数则为33256256+33256256 ≈118w,而右侧的参数为2561164+336464+1164*256 ≈7w,所以深层网络用右侧模型更好!可以看到右侧是实线结构,所以输入和输出的维度一定要一致!,因为在输出节点,残差结构F(X)xF(X)要和输入x相加


既然上面我们提到了输入和输出的维度一定要一致,那如果不一致呢,此时需要的就是虚线结构。
我们先来看看论文中不同模型所对应的不同网络结构,如下图所示:
Pytorch: ResNet论文学习解析网络结构并pytorch实现
可以看到对于34-layer和50-layer,第二层的输出分别是64和256,而第三层的输入均是128,和上一层的输出不匹配,所以要进行相应的升维和降维操作,该操作由1*1的卷积核来实现。
对于低层的网络,根据下图实现维度不匹配问题,通过该层,可以将64维的数据升到128维。
Pytorch: ResNet论文学习解析网络结构并pytorch实现
同理,对于深层网络,利用下图的网络进行升维操作。
Pytorch: ResNet论文学习解析网络结构并pytorch实现
到此,虚线结构和实线结构的网络已经了解了长什么样子,再看看resnet网络和其他网络的区别。
Pytorch: ResNet论文学习解析网络结构并pytorch实现
该图也是以34-layer为例。我们可以发现resne和plain扁平网络对比,也就是增加了shortcut connection结构,而旁边的shortcut有实现和虚线之分,那上面的数字是什么意思呢,回到上面的不同网络比较的表格,可以看到34-layer的层数分别是【3,4,6,3】啦,所以是指该种网络在有多少个。最下面是表现也很优秀的VGG-19网络,但是可以计算一下,VGG所需要的参数非常多,而ResNet的参数确实很少的。所以ResNet是在保证精度的同时还降低了算法复杂度。


再看看作者在后文还提出了变态的resnet1202层网络,如下图:
Pytorch: ResNet论文学习解析网络结构并pytorch实现
在最后一张图还是发现,resnet-1202的错误率比resnet-110高一点,但从前两种图还是可以发现,随着ResNet网络的层数越深,得到的效果越好,更能验证残差网络解决了传统模型存在的退化问题。

3、pytorch利用ResNet进行迁移学习

我进行迁移学习的数据集是CIFAR10,也可以下载其他的数据集,用了TensorFlow官方的花卉数据集试了下,也不错。附上GitHub链接如下。
https://github.com/njau-fyl/Classfication-ResNet