迁移学习 transfer learning

迁移学习 transfer learning

什么是迁移学习?

一句话概括:迁移学习是利用少量数据 (small-scale dataset),二次训练其它任务相关模型,以使其在本任务下表现良好的一种方法。

解释

在训练网络时,往往会出现过拟合 (overfitting) 的情况,其中一个原因是数据量不够。此时,应对 overfitting 的一个方法是正则化 (regularization),另一个方法则是迁移学习 (transfer learning)。

所谓正则化,实际上就是增加模型鲁棒性,防止过拟合的一种方法。常见的有加入L1、L2正则化项、dropout、data augmentation等等。(扯远了…

使用迁移学习的根本原因是数据量不够,导致其训练的模型在训练集上过拟合。此时就要拿出一个在相关领域,已经训练得很好的模型,在它的基础上使用现有的小规模数据集进行二次训练调整,最终达到一个不错的性能。

具体方法是如果你有一个不大的数据集,那么就固定 (freeze) 网络的前 n-1 层,只训练最后一层就好了;如果你很幸运有一个稍微大一点的数据集,那么你可以从最后一层往上,多训练几层以更加提升性能。 迁移学习 transfer learning
针对不同情况,我们可以将其分为四类:(领域相关,领域不太相关) × (数据集不大,数据集有点大)

相应的应对策略如下表所示
迁移学习 transfer learning
也就是说,在相关领域,如果你有一个不大的数据集,那么只训练最后一层就好了;如果你有一个挺大的数据集,那你就可以多精调几层网络。在不太相关的领域,如果你有一个不大的数据集,那就有点惨了,也许需要重新初始化网络权重,尝试在不同阶段 (从不太收敛到收敛这个阶段) 训练最后一层;如果你有一个还不错的数据集,那就可以多训练几层,保证网络提取的信息量。

总结

总结一下,迁移学习打破了训练神经网络需要大规模数据集的神话。你只需要找到一个相关领域训练得还不错的模型,freeze它的前面层,用现有的数据精调网络的后面几层就能达到一个不错的效果了。

PS:相关领域模型的选择最好比你的任务更广泛一些,比如使用YOLO9000模型,而只识别一些车辆、行人等目标。
PS2:前面举的例子是分类领域的,个人认为回归领域的问题应该也差不多 (有待学习)。

参考

cs231n课程第七章迁移学习,观看地址在bilibili