图神经网络Graph Neural Network

前言

在我上一篇博客,介绍基于random walk的节点表示方式,该方法的主要是思想是以one-hot的形式,经过Embedding层得到node vector,然后优化以下的似然函数来得到最优的Embedding Matrix

maxuVlogP(NR(u)zu)max \sum_{u \in V} logP(N_R(u)|z_u)

该方法有很多缺点

  • 需要V|V|大小的空间
  • 参数没有共享,一个节点对应一个embedding值
  • 图通常需要用到节点特征,该方法没有办法结合节点特征

本文将会介绍基于GNN的表示方式,尽可能解决以上的问题。

一种简单的方法

图像有CNN,序列问题有RNN,但是对于图结构来说,这些模型都不适用,图的节点数量不固定,通常都会有很复杂的拓扑结构,无论是CNN还是RNN都没有办法处理这样动态的数据结构,那么如何解决这个问题呢?

最简单的方法就是采用图的邻接矩阵,并且把节点的特征拼接进来,再把拼接后的数据喂给一个神经网络。

图神经网络Graph Neural Network

该方法虽然可行,但是存在一些缺点

  • O(N)O(N)的参数量
  • 训练好的模型不适用于不同大小的图
  • 没有考虑图中的顺序

深度学习中的图

首先我们定义几个符号

  • G 图
  • V 图中的节点
  • A 邻接矩阵
  • X 节点特征XRm×VX \in \Reals^{m×|V|}

从之前的博文中,我们知道,一个节点的embedding是由其邻居节点决定的,那么我们是否可以使用神经网络来聚合其邻居节点的信息呢?答案是肯定的,我们给每一个节点根据其邻居定义一个计算图,如下图所示。

图神经网络Graph Neural Network

这样的结构让节点在每一层都有其embedding表示,其中第一层是每个节点的特征即xx,第k层是经过k度后的节点embedding信息(层约深,获取到的全局信息就越多)

图中灰色部分表示的是聚合操作,那么这个聚合操作到底是什么样的呢?我们这里介绍两种方法

  • 平均邻居信息
  • 采用神经网络

Average neighbor messages

先看第一种,取均值。假如我们要计算vv节点的embedding,首先始化第0层的embedding等于其节点特征

hv0=Xvh^0_v=X_v

然后计算下一层的emebdding,可以看到,计算分为了两部分,第一块是对vv节点的邻居节点的上一层的embedding取均值,然后与vv节点的上一层的embedding取加权平均数,这里的两个加权值就是我们需要训练的参数,最后外面加一层非线性变化,注意,这里的σ\sigma是指非线性变化,不一定是sigmoid函数

hvk=σ(WkuN(v)huk1N(v)+Bkhvk1)h_v^k=\sigma(W_k \sum_{u \in N(v)}\frac{h_u^{k-1}}{|N(v)|} + B_k h_v^{k-1})

最后节点v的embeddding就等于最后一层的embedding

zv=hvKz_v=h^K_v

知道了怎么计算embedding,那下一步就是考虑怎么训练模型,怎么定义损失函数了,这里有两种方法

  • 非监督学习,和random walk那篇博客讲的方法一样,不再赘述
  • 监督学习,采用节点分类任务训练embedding

这里着重说下监督学习,我们对所有的节点进行一个分类任务,假如是二分类,那么损失函数就是一个交叉熵损失,其中yvy_v是节点的真实标签,θ\theta是分类任务的训练参数。

L=vVyvlog(σ(zvTθ))+(1yv)log(1σ(zvTθ))L = \sum_{v\in V}y_v log(\sigma(z_v^T \theta))+(1-y_v)log(1-\sigma(z_v^T \theta))

训练的过程,我们可以把多个节点的embedding作为一个batch,如下图是3个节点对应的embedding

图神经网络Graph Neural Network

参数主要有三部分,分类任务的θ\theta、生成embedding的WKW_KBkB_k,这些参数对于不同的节点都是共享的

图神经网络Graph Neural Network

训练好的模型,只要是同样的场景都可以使用,例如我们对某有机物A构建了其蛋白质结构图,同样适用于有机物B。在工业场景中,图中新加一个节点也是很常见的情况,特别是社交网络这样的图,我们依旧不需要重新训练模型,直接对新加的节点使用训练好的神经网络进行embedding的生成即可。

GraphSAGE

上一节我们介绍了采用均值的方式来聚合邻居信息,这一节我们来看一个更好的方法。

首先我们要明确,需要优化的对象是

uN(v)huk1N(v)\sum_{u \in N(v)} \frac{h_u^{k-1}}{|N(v)|}

也就是说我们应该考虑使用其它聚合方式替换取均值即

AGG(huk1,uN(v))AGG({h_u^{k-1}, \forall u \in N(v)})

未完待续