贝叶斯角度对MAML的新的理解

本文依据文章title:
贝叶斯角度对MAML的新的理解
本文按照文章顺序进行,只抽取笔者认为的重点部分,如有不妥,还请看客给出意见建议一起进步。感兴趣的话,建议去biying原文阅读。


2. MAML以及分层贝叶斯表示的回顾

  • 2.1 元学习作为基于梯度的超参数优化器
    参数化的元学习目的是find一些shared参数,当面临novel task时,能够轻松的find适用于该task的参数。
    MAML提供了一中基于梯度的元学习过程,它使用相同元学习率对各个任务得到其一步梯度后的参数,实现fast adaptation。
    将MAML的学习目标表示如下:
    贝叶斯角度对MAML的新的理解
    可以看出,它在内循环使用多任务的参数更新,在外循环使用这些参数来计算相同任务不同采样的损失函数值,并对原始参数进行一次更新。通过这种方式,可以充分利用多任务中的梯度信息,从而有希望学到common的知识,成为下一次参数更新的先验。这其中很自然的蕴含了在线学习的思想,这也是为什么iclr2018的best paper基于MAML做了增强学习下复杂环境下的continue learning—实时动态。
  • 2.2 元学习作为分层的贝叶斯推理
    我们首先将MAML的参数更新过程表示成下面左图所示,并给出其概率图模型,从下面的分析中可以看出二者的联系。
    贝叶斯角度对MAML的新的理解
    左图反应的实际就是上面(1)式描述的事情。

要强调的一点是(1)式里的条件概率,对于我们的代价,可以很自然的将其表示为条件概率的形式,因为我们的最终输出是在给定模型参数下样本属于各类的概率。

这样就很好理解图1的左边部分了。右边部分实际上是左边计算图的一个概率图模型表示。具体来说就是,在某一次更新外循环参数时,固定为theta,此时,对于各个任务来说,theta作为模型参数的先验,每个任务的phi从theta中采样得到,满足iid条件,但是这个采样并不是随意的,它有一个目标就是能够根据该phi以高概率将该任务中的N样本xjn分类准确,即条件概率p(xjn|phi)大。而这个phi就是theta在各任务上的最大后验估计(MAP)。

要强调的一点是图1的左图是计算图,右图是概率图,概率图反应出来的仅仅是计算图的phi的后验是如何得到的这一步。之后的对先验theta的更新就和PAC-bayes里的可变先验的内容有点相似了,而参数以及样本的生成关系已经在概率图中全部表示出来了。

贝叶斯角度对MAML的新的理解
这部分原文主要就告诉我们基于MAML理解的分层的贝叶斯在多任务上是如何更新模型参数theta的。(2)式和(1)式在形式上虽然不同,但是做的事情是一样的。(2)式是(1)式的抽象表示,(1)式是(2)式在MAML框架下的具体实现。我们不妨坐下对比,可以发现:
(2)式中观测关于phi的条件概率对应于(1)式中的观测关于任务j更新后的参数的条件概率;(2)式中phi关于theta的条件概率对应于(1)式中最内层循环更新第j个任务的参数phi。

3.基于梯度的元学习和分层贝叶斯间的LINK
这部分将2.1和2.2两部分结合起来,提出了基于贝叶斯分层推理的MAML。并证明MAML的内循环更新任务参数phi对应贝叶斯推理中的先验知识的更新,通过改进该先验,进一步提高了MAML在多任务小样本场景下的识别性能。

  • 3.1 MAML作为empirical bayes(这里就是指分层的贝叶斯,原话见icml1998年的关于多任务和分层贝叶斯的论文,作者是Tom Heskes
    (2)式中phi关于theta的条件概率通常很难handle,于是,我们采用theta关于phi的点估计(MLE)重写(2)式如下:
    贝叶斯角度对MAML的新的理解
    其中,phi_hat表示任务j的phi的点估计。具体来说,如下式:
    贝叶斯角度对MAML的新的理解
    可以看出,MAML的更新方式实际上等价于目标函数关于元学习参数theta来最大化观测样本的边缘似然p(x|theta)。实际上,这个点估计可以通过任务j的数据采样利用一步或若干步梯度方便的获得。
    从贝叶斯角度理解MAML,它相当于通过若干步梯度计算出关于任务phi的后验,用这个后验来表示该任务从初始参数theta,对观测数据的一个较可信的估计,是一种在先验和任务上做的trade-off。有点类似于牛顿法中,在初始值附近迭代的寻找局部最优解的过程,唯一的不同在于,最后会根据这些多个局部最优解更新模型的先验theta,这实际上是一个变先验让模型对新事物有持续学习的能力的过程,而学习新事物所用到的元知识,就是根据经验(历史任务)学习到的初始化参数(点估计)。

    接下来,我们以线性回归为例,进一步理解上面提到的trade-off。
    贝叶斯角度对MAML的新的理解
    上式的目标函数是现行回归的典型表示。我们在sgd下设置步长参数来更新模型参数phi。
    贝叶斯角度对MAML的新的理解
    santos证明,当模型参数以theta为初始化时,在现行回归问题中,我们的目标函数变成(5)式,加上了初始参数和模型参数的Q范数约束。
    贝叶斯角度对MAML的新的理解
    最小化(5)式我们又可以将它写成最大化(6)中参数phi的边缘似然。注意到(5)式中的两个求和项分别对应着(6)式中的两个高斯分布。这通过简单的推导可以看出来。
    现在再对比(4)式和(6)式,可以看出,在theta的领域内,关于phi最小化(4)式的结果等价于最大化(6)式。从而,我们得出结论:给观测样本一个带噪的高斯分布以及令phi服从theta领域内的高斯分布的条件下,通过(4)式经过k步梯度下降求解出来的phi就是phi的MAP解。

在线性回归中,对某一次任务而言,用MAML对其参数进行更新的结果就等价于用empirical bayes 对phi做的MAP的点估计的结果,此时得到的phi是全局最优的。
与非线性问题对应,在这种场景下MAML可以有同样的解释,唯一的区别在于,此时的phi不是全局最优而是局部最优。

这部分的最后一小块,我们来简要回顾一下对于线性任务下早停和高斯先验的关系以及在非线性任务下对参数初始化这种implicit的正则和高斯先验的关系。任何一种以截断梯度更新参数的方式都隐含着对参数后验分布的MAP点估计,对于线性问题,这个估计对应全局最优解,对于非线性问题,这个估计对应局部最优解。

最后给出MAML的贝叶斯推导下的目标函数,
贝叶斯角度对MAML的新的理解
这启发我们,可以用其他的meta-optimization来估计phi的后验,从而可以进一步改进算法。这也是本文的motivation的由来!

下面,给出MAML贝叶斯理解下的算法框图,
贝叶斯角度对MAML的新的理解

  • 3.2 基于任务的参数的先验
    由3.1我们以二次目标函数为例,已经知道关于phi通过早停得到的fast adaptation的结果与给定初始化theta下phi的先验选择相对应。从梯度的角度理解,phi的更新仅用到了一阶的信息,现在,我们考虑关于phi的二阶近似估计。
    贝叶斯角度对MAML的新的理解
    我们的目标是要在phi_star的领域内找到最优的phi。(7)式相当于是对目标函数L在phi_star处的二阶泰勒展开,与牛顿法的更新方式的表示一样。
    进一步给出(7)式的参数更新公式,
    贝叶斯角度对MAML的新的理解
    这里用curvature矩阵B来近似Hessian矩阵的逆,到目前为止,关于meta-optimization的方式与牛顿法完全一致。可以说,本文的思想就是用牛顿法来做新的meta-optimization。但是,给它穿上了bayes的衣服,使其更加丰满了!

    现在为了最小化(7)式,在给定初始值phi_0的情况下,等价于最小化下面的式子:
    贝叶斯角度对MAML的新的理解
    文章中有一些关于(9)式的讨论,感兴趣的可以自己再读下文章的这部分章节。当目标函数是二次时,(9)式退化成为(5)式。

4.提升MAML的性能

  • 4.1 拉普拉斯推理方法
    在开始之前,有必要对拉普拉斯近似和MAP做点估计的方法做一比较,给出下面的博客链接,比较详细:
    贝叶斯推断之拉普拉斯近似

需要注意的是,拉普拉斯近似也是假设随机变量服从高斯分布,并求解其充分统计量对随机变量进行建模的。同时,它会利用牛顿法求出的参数的MAP点估计结果。

再回到文章,我们来看看作者引入拉普拉斯近似的原因。由上面我们知道,MAML对参数的后验做点估计。考虑到有可能phi关于theta的条件概率不是delta函数,此时MAML得到的点估计就会存在偏差。为了克服这个问题,一种方式就是对参数的分布建立一个在局部最优点附近的高斯分布,不仅估计均值(点估计),也估计方差,并从这个分布中采样参数,对参数进行平均来得到模型参数的后验估计。以此来降低估计的参数的偏差。而拉普拉斯近似正是具有我们希望的性质的这样一种求解技术!这里再列出文章中的原话,以防由于笔者理解的偏差带偏听众:

贝叶斯角度对MAML的新的理解

注意文中原文的意思是,会形成一个局部二次近似!

有了以上的认知,我们就可以在MAML对基于任务的phi(fast adaptation)的更新上融入这种uncertainty。
对(2)式,我们假设其有一个well的点估计phi_star,我们在其领域内对其做泰勒二阶展开,
贝叶斯角度对MAML的新的理解
其中,Hj是第j个任务下代价函数(似然函数)对参数phi的Hessian矩阵。拉普拉斯近似利用MAP的点估计phi_star作为局部最优,并在该解附近引入
方差构建了关于局部最优解的二次近似,使其对最优解的扰动有一定容忍度,这实际上进一步放宽了MAML的原始假设—p(phi|theta)在最优解附近满足delta函数。将上式代入(2)式,得到用拉普拉斯近似的MAML的目标函数:

贝叶斯角度对MAML的新的理解
这一小节的最后, 我给出文章中关于(11)式的一些说明,结合上面的介绍和讲解,很好理解。
贝叶斯角度对MAML的新的理解

  • 4.2 用曲率信息提升MAML性能
    文章分析了求解4.1中代价函数的困难,并提出用近似的fisher信息矩阵来求解(11)式中的Hessian矩阵的行列式。
    下面是fast adaptation基于任务的参数phi的算法:
    贝叶斯角度对MAML的新的理解
    贝叶斯角度对MAML的新的理解

MAML作为一个fundmental的方法,一经提出,就已经有至少三篇顶会基于该方法做了非常有趣的工作!希望以后能够多多出这样的算法,更希望看到我们国人做出这类fundmental的工作!!!共勉!加油!