Incremental Learning of Object Detectors without Catastrophic Forgetting详解
Incremental Learning of Object Detectors without Catastrophic Forgetting详解
最近由于项目的需要在研究incremental learning在目标检测方面的应用,刚好读到了INRIA在2007年的一篇paper,采用蒸馏loss的方法来做incremental learning的,所以写这篇博客记录下来。
概述
不懂什么叫incremental learning或者是catastrophic forgetting的可以参考知乎这个链接,王乃岩介绍的非常完善,自己也学到了不少。
CNN用于目标检测任务的缺陷——类别遗忘:假设CNN模型A为在一个物体检测训练集1上训练得到的性能较好的检测器,现在有另外一个训练集2,其中物体类别与1不同,使用训练集2在A的基础上进行fine-tune得到模型B,模型B在训练集2中的类别上可以达到比较好的检测结果,但是在训练集1中的类别上检测性能就会大幅度下降;
本文目的:缓解CNN用于目标检测任务的类别遗忘,在训练集1中原始图片不可得以及新图片中不包含训练集1中存在的类别的标注的情况下,在训练集2上fine-tune模型A得到模型B,可以同时在训练集1和2中的类别上获得较好的检测性能;
本文核心:在fine-tune模型A得到模型B的过程中提出一个新的损失函数,用于同时考虑网络在新的类别上的预测性能以及原始类别在新模型B和旧模型A上的响应差异,LOSS=新类别检测LOSS+旧类别在模型A和模型B上的差异LOSS
网络结构
作者也提出:解决这个新增分类的问题可以再模型A上增加对新类别的预测分支,随即初始化该分支后,用新类别数据fine-tune这个分支,但是这样做会导致一个问题,此时得到的网络对原来N个类别的检测性能会大幅下降。所以作者提出了一种新的loss,既能够检测出新的类,同时也能保证在旧的类的检测准确率不会下降。网络结构如下:
Network A:It contains a frozen copy of the original detector。作用:1)检测原始类别的bbox;2)蒸馏proposals并计算蒸馏loss;
Network B:用于对新增分类B的网路,结合模型A最终可以预测出新的类和旧的类;
作者指出:选择fast-rcnn而非选择faster-rcnn,因为faster-rcnn中有RPN层,其对类别有一定的敏感性,因为RPN可被训练且共享卷积,,不利于最后蒸馏loss的计算,所以作者选基于edgeboxes的fast-rcnn,因为其类别对proposal不敏感。
在作者的这个fast-rcnn中,将vgg16替换为resnet50,并在最后一层stride!= 1的卷积层前加入了RoI pooling层,然后在接上剩下的卷积层和两层FC连接每个类别的得分输出和回归输出,使用该主干网络训练用于检测类别集合1的模型A。
loss_cls层评估分类代价。由真实分类u对应的概率决定:
=−log
=−log
loss_bbox评估检测框定位代价。比较真实分类对应的预测参数和真实平移缩放参数为的差别:
=g(−)
g为Smooth L1误差,对outlier不敏感:
总代价为两者加权和,如果分类为背景则不考虑定位代价:
这个详细的可以参考fast-rcnn原paper,这里不详说。
训练方法
首先训练一个fast-rcnn的网络结构使其能够检测原本的数据集,这个网络结构记为A()。所以我们现在的目标是曾杰一个新的类数据集。
我们对先前训练得到的网络A()做两份copies:一个冻结的网络通过蒸馏loss对原来的进行检测识别;另外一个B()被扩充用来检测新的分类(在元数据中未出现或未被标注)。我们创建一个新的FC层用来只对新的分类检测,然后将其output和原来的的输出做concat,即:根据新增加的类别数对网络A进行扩展,即增加全连接层的输出个数,得到初始化的Network B网络。新的层是采用和先前的网络A一样的初始化方式进行随机初始化的。现在我们的目标就是:训练一个网络能够仅仅使用新的数据,最后能够识别出新增分类和旧分类的网络。
作者指出蒸馏loss是为了“keeping all the answers of the network the same or as close as possible”。如果我们训练网络B()不做蒸馏的话,这个网络的性能在原来的类上将会急剧下降,这就是所谓的catastrophic forgetting(灾难性遗忘)。Even if no object is detected by A(CA), the unnormalized logits (softmax input) carry enough information to “distill” the knowledge of the old classes from A() to B().
细节
对于每一个训练图片,随机从128个RoI中选取64的背景得分最低的RoI,并分别得到其通过模型A后在旧的类别集合上的得分和回归目标,同样得到其在通过模型B后在旧的类别集合上的得分和回归目标。
Loss函数包括logits(即softmax的input)和回归的outputs:
N:用于蒸馏的RoIs的数量(文章选的64)
||:原始数据的类别数
:bounding box regression outputs
蒸馏logits不使用任何的smoothing,因为大多数的proposals已经经历了smoothing在分数的分布上。在我们的试验中,在初始阶段,新的和旧的网络的参数基本一致,所以没必要smoothing来稳定其训练。
所以总的损失函数定义如下:
采样策略
作者实验发现:选择非背景proposal进行蒸馏学习相比随机选择proposal进行蒸馏学习得到的网络更检测性能更好。
其他的作者做的一些实验本文就不在这里叙述了。随后献上paper和作者的代码。