元学习—关系网络和匹配网络

元学习—关系网络和匹配网络

1 关系网络(Relation Network)

1.1 关系网络的基本函数

一个关系网络至少需要包含两个核心的函数,第一个核心的函数是编码函数f,该函数经支持集(Support Set)和查询数据(Query data)进行编码,用于后续的数据计算。其次,一个关系网络至少还需要包括一个关系函数g,而关系函数的作用则为根据查询数据和支持集中各个分类的相关性来决定查询点的最终分类。

下面,我们以一个one-shot的图片分类学习过程为例,即support set中各个分类的样本数量为1。此时Support集中的内容为{Lion,Elephant,Dog}\{Lion,Elephant,Dog\}

1.1.1 编码函数 f

对于编码函数f而言,其具体的方法有很多种,如对于图片信息抽取的CNN算法,对于文本信息抽取的RNN类方法。其计算过程为:
xi{Support set}, f(xi)=>Exix_i∈\{Support\ set\},\ f(x_i)=>E_{x_i}
xj{Query set}, f(xj)=>Exjx_j∈\{Query\ set\},\ f(x_j)=>E_{x_j}
其中xjx_j是查询样本,xix_i为Support中的一个样本,由于Support中每一个分类只有一个样本向量,那么xix_i即代表着一个分类的向量信息。

1.1.2 相似性计算

在进行编码之后,下一步就是将查询点的编码向量和各个分类对应的编码向量进行相似性计算,在之前我们提到过,在Support中每一个分类下面只有一个样本向量,那么我们上面计算出来的每一个xix_i的编码就可以代表一个分类的向量。

在计算完各个分类的向量之后,下一步就是相似性的计算。我们假设计算函数为Z,则计算过程如下所示:
Z(f(xi),f(xj))Z(f(x_i),f(x_j))
其中xjx_j为查询样本,xix_i为Support中的样本。其中Z(A,B)Z(A,B)表示将A,B进行拼接,在根据Z函数进行

最后,根据相似性函数Z的计算,我们可以获取查询样本xjx_j和每一个分类xix_i的相似性情况。进一步,我们使用函数g来将相似性转换成概率分值,最后选择概率分值最大的分类作为预测分类。其计算过程如下:
rij=g(Z(f(xi),f(xj)))r_{ij} = g(Z(f(x_i),f(x_j)))

1.2 总体流程

我们下面以一张图来总结上述的流程:

元学习—关系网络和匹配网络

1.3 few-shot学习和zero-shot学习

1.3.1 few-shot学习

与上述one-shot相比,few-shot在支持集中的样本数量要更多,所以,我们在处理各个分类的编码向量的时候要复杂一些。具体的方法有很多种,这里我们给出一个最简单的计算思路,即将Support Set中各个分类对应的各个样本向量进行求和的操作。此时的具体计算流程为:
xi{Support set}, f(xi)=>Exix_i∈\{Support\ set\},\ f(x_i)=>E_{x_i}
xj{Query set}, f(xj)=>Exjx_j∈\{Query\ set\},\ f(x_j)=>E_{x_j}
ExC=j=imExjE_{x_C}=∑_{j=i}^mE_{x_j}
其中C表示分类,i和m表示在i到m之间的样本均属于C类。
Z(Exc,Exj)Z(E_{x_c},E_{x_j})
rcj=g(Z(Exc,Exj))r_{cj}=g(Z(E_{x_c},E_{x_j}))
最终,其整体的计算过程如下图所示:

元学习—关系网络和匹配网络

1.3.2 zero-shot学习

对于Support Set中不存在样本的情况,我们需要各个分类的“元信息”,即各个分类的语义向量。此时,我们选择使用两个编码函数f1f_1f2f_2来分别对元信息和查询样本进行编码,计算过程如下:
f1(Vc), f2(xj)f_1(V_c),\ f_2(x_j)
其中VcV_c表示第c个分类的元信息(语义向量)。后续Z,g的计算过程与上述类似,这里不再赘述。

2 匹配网络(Matching Network)

2.1 基本计算过程

匹配网络也是一个简单有效的one-shot的学习方法。其能够产生不在训练集中出现的类别标签。在支持集Support Set中存在着K个样本。(x1,y1),(x2,y2),(x3,y3),....(xk,yk)(x_1,y_1),(x_2,y_2),(x_3,y_3),....(x_k,y_k)。此时给定一个查询数据xx^-,通过匹配网络,将xx^-与Support Set中的数据进行比较来确定该节点的类别标签。

进一步,我们定义查询点属于某一个分类的概率计算公式:
P(yx,S)P(y^-|x^-,S)
其中yy^-是预测标签,xx^-是查询点,S是支持集。在计算结束之后,我们来选择概率值最高的类别作为预测类别。

下面,我们来定义如何计算概率P,给出下面的计算公式:
y=i=1kα(x,xi)yiy^-=∑_{i=1}^kα(x^-,x_i)y_i
上述公式中,xi,yix_i,y_i表示的是支持集中的样本点和其对应的标签。而α表示的是一种注意力机制,即attention结果。attention的计算是通过softmax+cosine函数相结合计算出来的,其计算公式如下所示:
α(x,xi)=softmax(cosine(x,xi))α(x^-,x_i)=softmax(cosine(x^-,x_i))
一般来说,对于样本点x,xix^-,x_i,我们无法直接计算其cosine相似度。所以,我们在计算cosine相似度之前,需要对x,xix^-,x_i进行编码。此时,我们使用函数ffgg来分别学习xx^-xix_i来进行编码。则最后的计算公式为:
α(x,xi)=softmax(cosine(f(x),g(xi)))α(x^-,x_i)=softmax(cosine(f(x^-),g(x_i)))
将公式扩展之后,为:
α(x,xi)=ecosine(f(x),g(xi))j=1kecosine(f(x),g(xj))α(x^-,x_i)=\frac{e^{cosine(f(x^-),g(x_i))}}{∑_{j=1}^ke^{cosine(f(x^-),g(x_j))}}

在计算完Attention的过程之后,我们下一步需要计算的Attention结果和标签yiy_i的乘积,需要知道的是,我们计算出来的Attention结果是一个向量值,所以需要将y_i转换成向量的形式,这里我们将yiy_i转换成one-hot的形式。然后在执行乘法的操作。计算的结果为xx^-属于支持集中每一个分类的概率。再通过对每一个支持集样本的计算结果的求和过程,最后,我们选择概率最大的结果对应的类别标签作为最后的标签结果。

2.2 模型整体结构

元学习—关系网络和匹配网络

3 总结

上述介绍了关系网络和匹配网络的基本的计算过程和整体流程,两种方法的思想都比较简单。其中可以根据具体的要求来修改距离计算方法,Attention计算机制等等。

4 参考

  1. Hands-On Meta Learning with Python