《Structured Knowledge Distillation for Dense Prediction》论文笔记

代码地址:structure_knowledge_distillation

1. 概述

导读:这篇文章针对的是密集预测的网络场景(如语义分割),在之前的一些工作中对于这一类网络的蒸馏时照搬分类任务中那种逐像素点的蒸馏方式(相当于是对每个像素的信息分别进行蒸馏),文章指出这样的产生的结果并不是最优的(这样策略会忽视特征图里面的结构信息,像素信息之间是存在关联的),因而这篇文章提出了适应密集预测网络的蒸馏策略:1)pair-weise蒸馏:通过构建静态图(受pair-wise的马尔可夫随机场启发,增强特征图中sptial上的的相关性,使得可以学习到结构性信息)蒸馏对应的相似性;2)holistic蒸馏:使用对抗训练去蒸馏学生网络和教师网络的输出信息(这里使用了更加高维度的信息,目的是使得判别器无法判断信息的来源)。文章的方法在语义分割,深度估计与目标检测上进行了实验,其结果也显示了其有效性。

文章从传统的蒸馏方法开始进行分析,由于传统的蒸馏方法是逐像素点的蒸馏方式,对特征中的结构性信息并没有很好提取,对此文章针对性的给出两种蒸馏的策略:

  • 1)pair-wise方式的蒸馏:文章使用pair-wise的马尔科夫随机场框架来增强空间labelling的连续性,目标是对齐简单网络(student)和复杂网络(teacher)中学到的pair-wise特征,从而使得学生网络能够学习到更多的结构信息;
  • 2)holistic蒸馏:这里并不将知识迁移的维度限定在pair-wise与pixel-wise上,而是使用对抗训练的形式监督学生与教师网络的输出,使其在更高的维度上进行近似逼近。判别器考虑的是网络输入图像(作为条件输入)与网络输出组成的holistic embedding,使得学生网络生成的结果不断近似教师网络。

使用文章的方法进行蒸馏,其在相应的baseline上得到的性能比较见下图所示:
《Structured Knowledge Distillation for Dense Prediction》论文笔记

2. 方法设计

2.1 蒸馏的整体结构

对于分割任务其流程是:对于个3通道的输入图像IRWH3I\in R^{W*H*3},它在经过卷积网络特征抽取之后,得到特征图FRWHNF \in R^{W^{‘}*H^{‘}*N}的特征图(论文代码中给出的stride=16),之后对其使用分类得到分类类别为C的结果QRWHCQ \in R^{W^{‘}*H^{‘}*C},之后将其上采样与原始输入图像尺寸保持一致。因而对于像分割这类的密集预测问题,文章设计了图2的蒸馏结构:
《Structured Knowledge Distillation for Dense Prediction》论文笔记
在上图中总共设计了3种类型的蒸馏损失:pair-wise的蒸馏损失,pixel-wise的蒸馏损失,Wasserstein距离损失(学生网络还有分割交叉墒损失)

2.2 Pixel-wise蒸馏

文中使用标记SS代表学生网络,TT代表教师网络。对于分割部分特征图QQ其使用的是原始的蒸馏方式,使用KL散度计算差异,因而这部分的损失函数描述为:
Lpi(S)=1WHiR(pis,pit)L_{pi}(S)=\frac{1}{W^{‘}*H^{‘}}\sum_{i\in R}(p_i^s,p_i^t)
其中,pis,pitp_i^s,p_i^t代表来自学生网络与教师网络的概率值,R={1,2,,WH}R=\{1,2,\dots,W^{‘}*H^{‘}\}代表特征图上像素。

2.3 Pair-wise蒸馏

pair-wise的马尔可夫随机场被广泛用于增强spatial维度的连续性,因而这篇文章中利用了其在spatial维度的pair-wise相关性进行知识蒸馏。

文章中是通过构建局部相关图的形式进行知识迁移,在文章定义好的图中包含了很多node,这些node代表的是spatial上的不同位置与两个相邻node之间的相关性。对此文章首先定义了两个超参数:计算相关性窗口的大小α\alpha以及计算时候的粒度大小β\beta 。从而由这两个参数所组成的一个局部相关性node见下图3所示:
《Structured Knowledge Distillation for Dense Prediction》论文笔记
因而对于一个特征图HWCH^{‘}*W^{‘}*C,其在spatial维度上包含了HWβ\frac{H^{‘}*W^{‘}}{\beta}个node信息,每个node信息HWβα\frac{H^{‘}*W^{‘}}{\beta}*\alpha个spatial上的node连接(论文代码采用的α\alpha默认为最大α=(HWβ)2\alpha=(\frac{H^{‘}*W^{‘}}{\beta})^2)。

使用aijt,aijsa_{ij}^t,a_{ij}^s分别代表教师与学生网络在第iinode与计算窗口内第jjnode的相关性。因而pair-wise的损失函数就被定义为:
Lpa(S)=βHWαiRjα(aijsaijt)2L_{pa}(S)=\frac{\beta}{H^{‘}*W^{‘}*\alpha}\sum_{i\in R^{‘}}\sum_{j\in \alpha}(a_{ij}^s-a_{ij}^t)^2
R={1,2,,HWβ}R^{‘}=\{1,2,\dots,\frac{H^{‘}*W^{‘}}{\beta}\}。在计算node信息的时候遇到α\alpha不为1的情况,则对应的特征图切片之后为βC\beta*C,之后在第一个维度上shying均值池化得到1C1*C的特征,之后在将这些特征使用下面的距离度量方式进行计算:
aij=fiTfjfi2fj2a_{ij}=\frac{f_i^Tf_j}{||f_i||_2 ||f_j||_2}

2.4 Holistic蒸馏

在前面的内容中从pixel-wise和pair-wise两个维度上进行蒸馏,这篇文章中还引入了更加高级的特征蒸馏方式,即是引入GAN网络的思路,使得学生网络的输出趋向于教师网络。

文章中将学生网络看作是带有条件的生成器(条件为输入的RGB图像II),得到对应的特征图QsQ^s,在对抗网络中扮演假值。而教师网络的输出为QtQ^t,在对抗网络中扮演真值。这两个数据经过文章设计的判别器网络D()D(\cdot)(由卷积,self-attention模块组成的评价网络)产生对应的embedding特征,之后使用Wasserstein距离度量两个embedding之后分布(ps(Qs)p_s(Q^s)pt(Qt)p_t(Q^t))的距离,通过梯度反传逐步实现学生网络向教师网络输出的逼近。因而这部分的GAN损失可以描述为:
Lho(S,D)=EQsps(Qs)[D(QsI)]EQtpt(Qt)[D(QtI)]L_{ho}(S,D)=E_{Q^s-p_s(Q^s)}[D(Q^s|I)]-E_{Q^t-p_t(Q^t)}[D(Q^t|I)]

2.5 网络优化过程

网络的损失除了前面提到的三个蒸馏损失之外,还有一个学生网络的分割交叉墒损失,可以描述为:
L(S,D)=Lmc(S)+λ1(Lpi(S)+Lpa(S))λ2Lho(S,D)L(S,D)=L_{mc}(S)+\lambda_1(L_{pi}(S)+L_{pa}(S))-\lambda_2L_{ho}(S,D)
其中,λ1=10,λ2=0.1\lambda_1=10,\lambda_2=0.1,文中指出学生网络SS与判别器DD是交叉进行训练的,其训练的步骤可以归纳为:

  • 1)训篇判别器,固定前面的网络,训练判别器就是最小化之前起到的判别器损失Lho(S,D)L_{ho}(S,D),也就是判别embedding特征的来源;
  • 2)训练学生网络SS,这里固定判别器网络DD,目标是最小化交叉墒损失与蒸馏损失,可以描述为:
    Lmc(S)+λ1(Lpi(S)+Lpa(S))λ2Lhos(S)L_{mc}(S)+\lambda_1(L_{pi}(S)+L_{pa}(S))-\lambda_2L_{ho}^s(S)
    其中,Lhos(S)=EQsps(Qs)[D(QsI)]L_{ho}^s(S)=E_{Q^s-p_s(Q^s)}[D(Q^s|I)],期望学生网络的生成结果能到判别器中得到较高的分值。

3. 实验结构

文章几个蒸馏损失对于最后性能的影响:
《Structured Knowledge Distillation for Dense Prediction》论文笔记
文章中α,β\alpha,\beta两个超参数与最后结果的影响:
《Structured Knowledge Distillation for Dense Prediction》论文笔记

安利时间:

PyTorch从入门到实战一次学会