元学习(meta learning)入门笔记--MAML

1、meta learning

元学习(meta learning)入门笔记--MAML

对于经典的深度学习方法,我们是通过人为定义的网络结构,人工设计的参数初始化方法,还有人工设计的梯度更新策略,来逐步更新函数f的参数,最终找到一个对当前数据较好的函数f,如下:

元学习(meta learning)入门笔记--MAML

对于元学习方法,我们希望能够替换这些人工设计的部分,让机器自己学会什么样的设置对训练任务是有利的,即找到一个函数F,它能根据数据,给出一个比较好的函数f,这也就是“学会去学习”。于是,对于元学习而言,我们可以定义一系列的learning algorithm,比如不同的网络参数设置就对应不同的learning algorithm,或者说不同的梯度更新方法对应不同的learning algorithm(这里,不同的定义对应不同的元学习方法,比如MAML就是学习一套好的网络初始化参数,还有一些方法是直接动态预测网络模型的参数)。接下来的问题是如何评估这些不同learning algorithm的好坏?对应到机器学习中,我们是通过定义loss function,然后根据样本loss来判断f的好坏,同样的道理,在元学习中,我们在更高的层次去评估,也就是可以定义不同的task,根据learning algorithm在不同task上的表现来判断:

元学习(meta learning)入门笔记--MAML

​这里,不同的task包含了完整的train set和test set,比如,可以认为是不同识别任务:

元学习(meta learning)入门笔记--MAML

​还有个小点是,一般如果不同task包含的train和test set过大的话,这个训练任务过于庞大费时,因此一般都假设使用到的train 和test set样本不多,这也就与few shot方法经常联系在一起了。N-way K-shot是few-shot learning中常见的实验设置。N-way指训练数据中有N个类别,K-shot指每个类别下有K个被标记数据。

2、MAML

目的:学习一套好的网络初始化参数

损失:提供一套初始化参数,然后让网络在不同task上去训练,最终得到属于不同task的网络参数,然后评估这些不同的网络参数在各自task上的测试集里表现如何,如此就可以得知最初的这套初始化参数到底怎么样。

元学习(meta learning)入门笔记--MAML

​实际实现时,为了简化训练,我们定义网络在不同task上只进行一次梯度更新训练后得到的参数就是属于不同task的最终网络参数,其实这就相当于对初始化参数进行一次梯度更新:

元学习(meta learning)入门笔记--MAML

​接下来需要求meta learning总的评估函数F对网络初始参数元学习(meta learning)入门笔记--MAML的梯度计算:

元学习(meta learning)入门笔记--MAML

元学习(meta learning)入门笔记--MAML

​这里为了简化计算,进一步进行一阶近似:

元学习(meta learning)入门笔记--MAML

元学习(meta learning)入门笔记--MAML

讲的更加形象一点,具体实现时,梯度更新方法如下:

元学习(meta learning)入门笔记--MAML

​也就是,task m从????0更新到????m就是task m上的最终結果(1次),但我们还是故意再更新一次,也就是接着计算????m的梯度(即上图中第二根绿色箭头),将这一梯度乘以学习率赋给????0,得到????0的梯度更新结果????1,如此往复。我们再来看MAML论文中的算法流程,就好理解多了,表示如下:​

元学习(meta learning)入门笔记--MAML

1、我们用于训练的模型架构是元学习(meta learning)入门笔记--MAML(假设初始化参数为元学习(meta learning)入门笔记--MAML​),这可能是一个输出节点为5的CNN,训练的目的是为了使得模型有较优秀的初始化参数。最终我们想要学出可以用于数据集元学习(meta learning)入门笔记--MAML分类的模型是元学习(meta learning)入门笔记--MAML​,元学习(meta learning)入门笔记--MAML​ 和 元学习(meta learning)入门笔记--MAML的结构是一模一样的,不同的是模型参数。

2、我们将1个任务task的support set去训练元学习(meta learning)入门笔记--MAML​ ,这里进行第一种梯度下降,假设每个任务只进行一次梯度下降,也就是元学习(meta learning)入门笔记--MAML。那么执行第2个task训练时,有 ​元学习(meta learning)入门笔记--MAML

3、上述步骤2用了batch size个task对元学习(meta learning)入门笔记--MAML进行了训练,然后我们使用上述batch个task中地query set去测试参数为​元学习(meta learning)入门笔记--MAML元学习(meta learning)入门笔记--MAML模型效果,获得总损失函数元学习(meta learning)入门笔记--MAML,这个损失函数就是一个batch task中每个task的query set在各自参数为元学习(meta learning)入门笔记--MAML元学习(meta learning)入门笔记--MAML中的损失之和。

4、获得总损失函数后,我们就要对其进行第二种的梯度下降。即更新初始化参数元学习(meta learning)入门笔记--MAML,也就是元学习(meta learning)入门笔记--MAML来更新初始化参数。这样不断地从步骤2开始训练,最终能够在数据集上获得该模型比较好的初始化参数。

5、根据这个初始化的参数以及该模型,我们用数据集元学习(meta learning)入门笔记--MAML的support set对模型进行微调,这时候的梯度下降步数可以设置更多一点,不像训练时候(在第一次梯度下降过程中)只进行一步梯度下降。

6、最后微调结束后,使用元学习(meta learning)入门笔记--MAML的query set进行模型的评估。

参考:

【1】https://www.bilibili.com/video/av46561029/?p=41

【2】https://zhuanlan.zhihu.com/p/181709693

【3】https://zhuanlan.zhihu.com/p/66926599

【4】https://blog.csdn.net/shaoyue1234/article/details/102400044