2020李宏毅学习笔记——34.Network Compression(3_6)

3.Knowledge Distillation(知识蒸馏)

整个知识蒸馏过程中会用到两个模型:大模型(Teacher Net)和小模型(Student Net)。

3.1 具体方法
先用大模型在数据集上学习到收敛,并且这个大模型要学的还不错,因为后面我们要用大模型当老师来教小模型学习嘛,如果大模型本身都没学好还教个锤子,对吧?1和7长得蛮像的。所以这里的损失函数用的是交叉熵,不能用简单的平方差之类的。

3.2 举例子
我们以MNIST数据集为例,假设大模型训练好了,现在对于一张数字为“1”的图像,大模型的输出结果是由0.7的概率是1,0.2的概率是7,0.1的概率是9,这是不是有一定的道理?相比如传统的one-hot格式的label信息,这样的label包含更多的信息,所以Student Net要做的事情就是对于这张数字为“1”的图像,它的输出结果也要尽量接近Teacher Net的预测结果。

为什么跟着老师学学的好呢 ?因为老师不知会交给你这是1,还会告诉你1和7很像。
2020李宏毅学习笔记——34.Network Compression(3_6)
当然,一个更骚气的办法就是让多个老师出谋划策来教学生,即用Ensemble Net来进一步提升预测准确率,让学生学习的知识更加准确。
2020李宏毅学习笔记——34.Network Compression(3_6)
那Student Net到底如何学习呢?首先回顾一下在多类别分类任务中,我们用到的是softmax来计算最终的概率,即yi。一般做cross entro最后都有一个softmax。但是这样有一个缺点,因为使用了指数函数,如果在使用softmax之前的预测值是x1=100,x2=10,x3=1,那么使用softmax之后三者对应的概率接近于y1=1,y2=0,y3=0,会发现这样小模型根本没有学到另外两个分类的信息。那这和常规的label无异(只是学到了这是一,没有1和7像)了,所以为了解决这个问题就引入了一个新的参数T,称之为Temperature,即有:后面的yi。此时,如果我们令T=100,那么最后的预测概率是y1=0.56,y2=0.23,y3=0.21。发现我们通过T把y的差距变小了,导致各个分类都有几率,小模型学习的信息就丰富了
T是超参数。(不过李宏毅老师在视频里提到说这个方法在实际使用时貌似用处不大)