prototypical networks for few-shot learning论文

论文:prototypical networks for few-shot learning
地址:https://arxiv.org/abs/1703.05175v2
code:https://github.com/jakesnell/prototypical-networks

摘要

针对小样本分类任务,作者提出了一种原型网络,分类器对于在训练集中未出现的新类别必须具有足够好的泛化性,每个新类仅有少量数据。原型网络学习的是度量空间,通过计算与梅格雷的原型表示的距离来进行分类。与当时的小样本学习的方法相比,原型网络反映了一种更简单的归纳偏置,在数据有限的情况下往往会取得很好的效果。论文中的分析表明,一些简单的设计决策,可以对涉及复杂结构选择和元学习(meta-learning)产生实质性的改进。论文进一步将原型网络扩展到zero-learning,在CUB数据集上取得了sota的效果。

引言

论文中要解决的就是小样本学习中由于数据量较少而导致的过拟合问题。论文中提出的原型网络就是使用神经网络将输入映射到一个度量空间,用类原型(ckc_k)来表示support set中的每一类。在分类任务中,将需要分类的数据映射到度量空间为xx,然后与类原型ckc_k比较距离,与那个近就属于那一类。如下图所示。
prototypical networks for few-shot learning论文

Prototypical network

Notaion

support set*有N个带有标签的数据。S={(x1,y1),,(xN,yN)}S = \{(\mathbf{x}_1,y_1),\dots,(\mathbf{x}_N,y_N)\},其中xiRD\mathbf{x}_i \in \mathbb{R}^D,D维特征向量。yi{1,,K}y_i \in \{1,\dots,K\}为对应的标签。SkS_k表示kk类的带标签的数据。

model

计算类原型ckc_kckRMc_k \in \mathbb{R}^M
embedding 函数:fϕ:RDRMf_\phi:\mathbb{R}^D \rightarrow \mathbb{R}^M,这里用的神经网络。
prototypical networks for few-shot learning论文
用softmax分类:
prototypical networks for few-shot learning论文
优化是用过SGD最小化J(ϕ)=logpϕ(y=kx)J(\phi) = -log p_\phi(y=k|\mathbf{x})
prototypical networks for few-shot learning论文
算法主要包含两个部分:
1、对episode中的每一类计算原型ckc_k,是由该类的所有数据的向量化表示求均值求得。
2、用query set中的数据对算法进行优化。

zero-shot Learning

零样本学习不同于少样本学习,其meta-data向量vk\mathbf{v}_k不是由训练集中的支持样本生成的,而是根据每个类的属性描述、原始数据等生成的。这些信息都是可以提取确定或者从原始数据中得到的。原形网络也能过灵活的转变成零样本学习,我们简单的定义ck=gθv(k)\mathbf{c}_k = g_\theta \mathbf{v}(k)为一个meta-data向量。

Experiments

Omniglot Few-shot Classification

Omniglot是包含50中字母的1623张手写字符数据集。每个episode包含60类,每类5个查询数据。
prototypical networks for few-shot learning论文

miniImageNet Few-shot Classification

minilmageNet数据集包含100个类别,每个类别中包含600个样本数据。其中64个类别数据作为训练集,16个类别数据作为验证集,20个类别数据作为测试集。该文分别使用30-way的episode对1-shot类和20-way的episode对5-shot的样本数据进行训练。在训练和测试时保持shot数目一致,query查询点的个数为每个类别15个。
prototypical networks for few-shot learning论文
实验分别对1-shot和5-shot的设置进行训练episode为5-way和20-way的训练,结果表明训练episode中设置更高的类别,对实验的结果有一定的增益效果。这是因为更高的way设置有助于网络进行更好的泛化,迫使模型在embedding空间做出更细粒度的决策。
prototypical networks for few-shot learning论文

CUB Zero-shot Classification

CUB数据集包含训练集包含100个类别,验证集包含50个类别,测试集包含50个类别。对于312维度的元向量,模型对鸟类的种类、颜色、羽毛等属性进行编码得到。训练episode的类别为50,每个类别的查询点为10个。
prototypical networks for few-shot learning论文

参考

小样本学习(few-shot learning)之——原形网络(Prototypical Networks)