小样本学习(few-shot learning)之——原形网络(Prototypical Networks)

Prototypical Networks for Few-shot Learning


摘要:该文提出了一种可以用于few-shot learning的原形网络(prototypical networks)。该网络能识别出在训练过程中从未见过的新的类别,并且对于每个类别只需要很少的样例数据。原形网络将每个类别中的样例数据映射到一个空间当中,并且提取他们的“均值”来表示为该类的原形(prototype)。使用欧几里得距离作为距离度量,训练使得本类别数据到本类原形表示的距离为最近,到其他类原形表示的距离较远。测试时,对测试数据到各个类别的原形数据的距离做softmax,来判断测试数据的类别标签。

Prototypical Networks

Notation

在few-shot分类任务中,小样本学习(few-shot learning)之——原形网络(Prototypical Networks)小样本学习(few-shot learning)之——原形网络(Prototypical Networks)  为一组小规模的N标签的支持数据集。x是D维的原始数据的向量化表示,y为其对应的类别,Sk代表类别为k的数据集合。

Model

原形网络要为每个类别计算出一个原形表示Ck,通过一个embedding函数 小样本学习(few-shot learning)之——原形网络(Prototypical Networks)将维度D的样例数据映射到M维的空间上。类别的原形表示Ck是对支持集中的所有的向量化样例数据取均值得到的。

小样本学习(few-shot learning)之——原形网络(Prototypical Networks)

在测试时,原形网络使用softmax作用在query向量点到Ck的距离。

小样本学习(few-shot learning)之——原形网络(Prototypical Networks)

训练过程是通过随机梯度下降法最小化目标函数:小样本学习(few-shot learning)之——原形网络(Prototypical Networks)其中k为训练样本的真实标签。训练的episode为随机从训练集中选择的一个类子集,从这些类子集中选择一些样例数据作为支持(support set)集,其剩余的作为查询(query set)集。完整的loss计算过程我们参考伪码:

小样本学习(few-shot learning)之——原形网络(Prototypical Networks)

 

其中主要包含两个主要的步骤;

(1) 对pisode中的每个类别都计算出一个原形Ck,其Ck的计算是对该类中的所有支持数据的向量化表示取均值求得。

(2)优化类别中剩余的query点到原形的距离来训练模型。

 

Design Choices


Distance metric:通过以往工作和本文实验得出,使用欧几里得距离来作为距离度量会明显的优于使用余弦距离作为距离度量。

Episode composition: 以往的实验发现,在训练和测试时保持相同的episode设置往往会得出较好的结果。例如,我们在测试时期望使用5-way-1-shot的方式,那么我们训练时就要使得episode的设置为Nc为5、Ns为1,其中Nc代表从episode中选择的类别的个数,Ns代表每个类别中被选择为支持样例的个数。然而,在我们的实验中发现,使用比测试时更高的Nc(“way”)对模型是有益的。


Zero-Shot Learning

零样本学习不同于少样本学习,其meta-data向量Vk不是由训练集中的支持样本生成的,而是根据每个类的属性描述、原始数据等生成的。这些信息都是可以提取确定或者从原始数据中得到的。原形网络也能过灵活的转变成零样本学习,我们简单的定义Ck=gv(k)为一个meta-data向量。对于零样本学习和少样本学习我们详见下图:

 

小样本学习(few-shot learning)之——原形网络(Prototypical Networks)

Experiments

Omniglot Few-shot Classification

该文使用原形网络在Omniglot数据集上进行实验,使用欧几里得距离作为距离度量,分布在1-shot和5-shot进行实验。训练episode的设置为60个类别和每个类别5个query查询点。实验结果发现在训练和测试时保持相同的training-shot(即:支持样本数据)和episode使用更多的类别会使得实验效果更好。

小样本学习(few-shot learning)之——原形网络(Prototypical Networks)

下图展示了一个可视化的手写体识别实验,其中黑色点代表每种类别的原形,红色代表被错误分类的数据,红色箭头的指向为真实的类别。

                  小样本学习(few-shot learning)之——原形网络(Prototypical Networks)

miniImageNet Few-shot Classification

minilmageNet数据集包含100个类别,每个类别中包含600个样本数据。其中64个类别数据作为训练集,16个类别数据作为验证集,20个类别数据作为测试集。该文分别使用30-way的episode对1-shot类和20-way的episode对5-shot的样本数据进行训练。在训练和测试时保持shot数目一致,query查询点的个数为每个类别15个。

小样本学习(few-shot learning)之——原形网络(Prototypical Networks)

实验分别对1-shot和5-shot的设置进行训练episode为5-way和20-way的训练,结果表明训练episode中设置更高的类别,对实验的结果有一定的增益效果。这是因为更高的way设置有助于网络进行更好的泛化,迫使模型在embedding空间做出更细粒度的决策。

小样本学习(few-shot learning)之——原形网络(Prototypical Networks)CUB Zero-shot Classification

CUB数据集包含训练集包含100个类别,验证集包含50个类别,测试集包含50个类别。对于312维度的元向量,模型对鸟类的种类、颜色、羽毛等属性进行编码得到。训练episode的类别为50,每个类别的查询点为10个。

小样本学习(few-shot learning)之——原形网络(Prototypical Networks)

论文链接:https://arxiv.org/abs/1703.05175