论文阅读笔记《Edge-Labeling Graph Neural Network for Few-shot Learning》
核心思想
本文采用基于图神经网络的算法实现了小样本学习任务,先前基于GNN的方法通常是基于节点标签框架,隐式地建立类内相似性和类间差异性的模型。而本文提出的边标签图卷积神经网络(Edge-labeling Graph Neural Network,EGNN)学习预测边标签而不是节点标签,这使其能够显式地表示类内相似性和类间差异性。这样描述还是很抽象,难以理解的,下面就直接介绍本文提出的模型。图模型包含三个部分:样本来自任务,而表示节点集合,表示边集合,和分别表示节点和边的特征,和分别表示节点和边的标签。
每个节点都对应一个样本,其初始值是来自一个嵌入式模型根据输入提取的特征向量。每个边缘特征是一个二维的向量,分别表示两个连接节点之间类间关系和类内关系的强度,其初始值如下
式中表示级联关系,EGNN的训练过程如下图所示
图中实线圆圈表示支持集样本,虚线圆圈表示查询集样本,不同颜色表示不同类别,方块表示两个节点之间的相似程度,颜色越深表示相似程度越高。整个网络分为L层,正向计算过程就是逐层的更新节点和边的特征。首先更新节点特征,其特征值是根据前一层的节点特征和边特征通过邻域聚合过程得到的,计算过程如下
式中表示节点特征变换网络,。类内聚合为目标节点提供了相似邻居的信息,而类间聚合则提供了不相似邻居的信息。而边特征则是根据更新后的节点特征,与度量网络来更新,计算过程如下
每个边特征的更新不仅考虑了对应节点的关系,而且考虑其他节点之间的关系。经过多次迭代更新之后,边标签的预测结果可以根据边特征获得,表示相邻的两个节点和来自同一类别的可能性,则节点所表示的样本属于第类的概率可表示为
式中表示如果内部等式成立则输出1,否则输出0。本文设计的网络还包含直推式(transductive)和非直推式(non-transductive)两种模式,直推式表示将所有的查询样本都同时放到一个图中,而非直推式则表示每次只添加一个查询样本。
实现过程
网络结构
本文设计的嵌入式网络f_{emb},节点特征变化网络,和度量网络的结构分别如图所示
损失函数
本文的损失函数是对边的预测值进行监督,计算方式如下
式中表示在第层网络的第个任务中的所有边预测值的集合,因为每层网络都可以输出预测结果,因此本文对每一层网络的输出都进行了监督,基础损失函数采用二元交叉熵损失函数。
训练策略
训练过程如下图所示
算法推广
本文设计的算法可以通过添加无标签样本实现半监督训练。
创新点
- 用边标签预测,取代了节点标签预测,显式地表示类内相似性与类间差异性,利用图神经网络实现小样本学习任务
算法评价
本文提出的算法相对于之前的基于图神经网络的算法而言,最大的变化就是使用边标签预测取代了节点标签预测,利用一个二维的边特征显式地表示了类内相似性与类间差异性,相对于其他基于图神经网络的小样本学习算法而言,本文在多个数据集上都取得了一定的进步。
如果大家对于深度学习与计算机视觉领域感兴趣,希望获得更多的知识分享与最新的论文解读,欢迎关注我的个人公众号“深视”。