Prototypical Networks for Few-shot Learning.(用于少样本分类的原型网络)

**

Prototypical Networks for Few-shot Learning.(用于少样本分类的原型网络)

**

摘要

文章提出了一种用于少样本分类的原型网络,其中分类器必须可以推广(泛化)到在训练集里面没有见过的新的类别,并且每个新的类别只有很少一部分样本。原型网络学习一个度量空间,执行分类只需要简单的计算到每个类的原型表示的距离。与其他方法的主要一点不同是原型网络反映了一种在数据集不足情况下的归纳偏差。最后一点是原型网络可以用于zero-shot learning(零样本学习)。

简单的设计选择带来了实质性的改进

本质思想

使用一个神经网络学习一个 embedding 将输入映射到一个映射空间里面,而每个类的原型就是简单的这个类的 support set 所有样本(support set 的定义参见后后文)的均值。分类的执行:将需要分类的样本用学习到的映射函数映射到嵌入空间里面,然后离他最近的原型就是我们的预测类。

介绍

  1. few-shot learning 问题的本质之一是样本很少,一般每个类的样本少于20个。另外一个本质是这些类是全新的类。few-shot learning 问题的分类器必须调整以适应新的类,并且新的类中每个类的样本很少。传统的方法是在新的数据上重新训练新的模型,但是由于样本太少,往往会导致过拟合 (overfitting)。但是我们人类是可以做到few-shot learning的,因此迈向普世的人工智能,few-shot leanring 问题是一个重要问题。
  2. 文章假设:由于样本有限,我们的分类器会有一个归纳偏差 (inductive bias)。 原型网络基于这样的想法:存在一种 embedding,使得其中每个类的点(可能多个)聚集在每个类的单个原型表示周围。
  3. 方法概述:使用一个神经网络学习一个 embedding 将输入映射到一个映射空间里面,而每个类的原型就是简单的这个类的 support set 所有样本(support set 的定义参见后后文)的均值。分类的执行:将需要分类的样本用学习到的映射函数映射到嵌入空间里面,然后离他最近的原型就是我们的预测类。zero-shot learning 类似,只不过每个类没有标记 (label),而是一个关于类的高级描述 (high-level description of the class)。

方法

  1. 符号定义:support set 是我们具有的 NN 个标记的样本 S={(x1,y1),,(xN,yN)}S=\left\{\left(\mathbf{x}_{1}, y_{1}\right), \ldots,\left(\mathbf{x}_{N}, y_{N}\right)\right\},其中 xiRD\mathbf{x}_{i} \in \mathbb{R}^{D} 是一个 DD维的向量,yi{1,,K}y_{i} \in\{1, \ldots, K\} 使对应的类,SkS_{k} 表示类别为 kk 的样本的集合。
  2. 模型:首先一个映射函数 (embedding):fϕ:RDRMf_{\phi} : \mathbb{R}^{D} \rightarrow \mathbb{R}^{M} 参数为 ϕ\phi,将输入空间映射到嵌入空间 (嵌入空间的维度为 MM),然后计算每个类的原型: ck=1Sk(xi,yi)Skfϕ(xi)\mathbf{c}_{k}=\frac{1}{\left|S_{k}\right|} \sum_{\left(\mathbf{x}_{i}, y_{i}\right) \in S_{k}} f_{\phi}\left(\mathbf{x}_{i}\right)。这个公式很简单,就是嵌入空间中所有 kk 类的所有点先求和在除以点的个数(就是简单的均值)。执行分类:距离函数 d:RM×RM[0,+)d : \mathbb{R}^{M} \times \mathbb{R}^{M} \rightarrow[0,+\infty) 根据两个 MM 维的向量,使用距离函数 dd,求出距离。对于一个需要分类的样本点 X\mathbf{X}X\mathbf{X} 属于类别 kk 的概率为:pϕ(y=kx)=exp(d(fϕ(x),ck))kexp(d(fϕ(x),ck))p_{\phi}(y=k | \mathbf{x})=\frac{\exp \left(-d\left(f_{\phi}(\mathbf{x}), \mathbf{c}_{k}\right)\right)}{\sum_{k^{\prime}} \exp \left(-d\left(f_{\phi}(\mathbf{x}), \mathbf{c}_{k^{\prime}}\right)\right)},即与原型距离的相反数的softmax(归一化),取相反数的原因是相距最小的原型最有可能使这个样本点对应的类。
  3. 学习过程是通过 SGD(Stochastic Gradient Descent, 随机梯度下降) 最小化正确类别 kk 的负对数概率 J(ϕ)=logpϕ(y=kx)J(\phi)=-\log p_{\phi}(y=k | \mathbf{x})。最小化 J(ϕ)=logpϕ(y=kx)J(\phi)=-\log p_{\phi}(y=k | \mathbf{x}) 等价于最大化 logpϕ(y=kx)\log p_{\phi}(y=k | \mathbf{x}),等价于最大化pϕ(y=kx)p_{\phi}(y=k | \mathbf{x})
  4. 训练剧集 (training episodes) 的构造,这也是我第一次在论文中看到以算法的形式介绍如何构建 meta-training task
    Prototypical Networks for Few-shot Learning.(用于少样本分类的原型网络)
    首先 NN 使训练集样本的数量,KK是训练集中样本的类,NCKN_{C} \leq K 是每次 training episodes 中类别个数,NSN_{S} 是 support set 中每个类的的个数,NQN_{Q} 使 query set 中每个类的的个数。RANDOMSAMPLE (S,N)(S, N) 表示在 SS 中随机均匀采样 NN 个样本,没有替换。
    Input:训练集 D={(x1,y1),,(xN,yN)}\mathcal{D}=\left\{\left(\mathbf{x}_{1}, y_{1}\right), \ldots,\left(\mathbf{x}_{N}, y_{N}\right)\right\},其中 yi{1,,K}y_{i} \in\{1, \ldots, K\}Dk\mathcal{D}_{k} 表示一个包含所有 yi=ky_{i}=k 的元素 (xi,yi)\left(\mathbf{x}_{i}, y_{i}\right) 集合 D\mathcal{D}
    Output: training episodes 的损失JJ
    过程:
    VV \leftarrow RANDOMSAMPLE ({1,,K},NC)\left(\{1, \ldots, K\}, N_{C}\right) 从K个类中随机取出 NCN_{C} 个类作为这个 training episode 的类别。
    对其中的每个类别 kkSkS_{k} \leftarrow RANDOMSAMPLE (DVk,NS)\left(\mathcal{D}_{V_{k}}, N_{S}\right),在所有类别为 kk 的样本中采样 NSN_{S}个样本作为这个类的 support。
    对其中的每个类别 kkQkQ_{k} \leftarrow RANDOMSAMPLE (DVk\Sk,NQ)\left(\mathcal{D}_{V_{k}} \backslash S_{k}, N_{Q}\right),在所有类别为 kk 的样本中采样 NQN_{Q}个样本作为这个类的 query。
    使用每个类的 support set 计算每个类的原型:ck1NC(xi,yi)Skfϕ(xi)\mathbf{c}_{k} \leftarrow \frac{1}{N_{C}} \sum_{\left(\mathbf{x}_{i}, y_{i}\right) \in S_{k}} f_{\phi}\left(\mathbf{x}_{i}\right)这里我觉得原论文有错:我用红色标记的地方,分母应该是 Sk\left|S_{k}\right|
    J0J \leftarrow 0:初始化 loss 为 0 。
    for kk in {1,,NC}\left\{1, \ldots, N_{C}\right\} do: for (x,y)(\mathbf{x}, y) in QkQ_{k} do: JJ+1NCNQ[d(fϕ(x),ck))+logkexp(d(fϕ(x),ck))]J \leftarrow J+\frac{1}{N_{C} N_{Q}}\left[d\left(f_{\phi}(\mathbf{x}), \mathbf{c}_{k}\right)\right)+\log \sum_{k^{\prime}} \exp \left(-d\left(f_{\phi}(\mathbf{x}), \mathbf{c}_{k}\right)\right) ]:叠加每个类的 query set 节点的损失。

进一步分析 (混合密度估计,距离度量的选择,训练剧集的构造)以及实验部分请参考原论文

这里讲一下我对归纳偏差的理解:对于 support set 中每个类的样本点不止一个的情况下,取任意一点作为这个类的基准都不合适,原型网络的思想是求平均,这样考虑到了每个样本点以及每个样本点的偏差。我们不妨认为其实每个样本点和这个类的原型表示点都有一点距离(偏差),求平均就是归纳偏差。当然一个想法就是不求平均,用别的方法找这个累的原型表示点,如果效果提升并且可以给出一定的理论证明,这是一个不错的future work 方向。另外距离欧式距离不一定的最好的选择,论文在欧氏距离和余弦距离中做出选择使根据实验得到的,至于为什么还有有没有一个更好的甚至最好的度量也是一个方向(后面会介绍的Relation Network就是从这个方向入手的)。

论文: https://arxiv.org/pdf/1703.05175.pdf.