图网络之——Graph Memory Networks

Graph Memory Networks for Molecular Activity Prediction


Introduction

graph(图)作为一种数据结构,能够表达非常复杂的关系,比如社交关系网络,见下图。充分挖掘graph中蕴含的知识,是一个非常challenging的任务,已有的方法像 kernel-based method运用到大量数学知识,值得学习一下,但在deep learning爆炸的年代,任何东西不和deep learning结合一下都没法上得了台面,可以Google到,近几年有许多graph+deep learning的工作涌现出来,有时间我打算单独对一些代表性的工作做一个系统的总结,今天,就以该文章作为一个引子,先感受一下。

图网络之——Graph Memory Networks

该文章做的是一个基于graph的分子活性预测问题。molecule(分子)是由若干个atom(原子)通过化学键连接在一起的,我们可以把一个molecule看作一个graph,通过挖掘molecular内部的连接关系形成对molecular的整体认识,然后就可以用来做molecule的生物活性预测了。

Approach

本文使用了一种叫做Graph Memory Networks (Graph-Mem)的方法,整体来讲是这样一个过程,如下图所示:

  • 1)建立一个memory,memory中为每一个node提供一个存储空间,用来存储该node的表达,也可以叫embedding;
  • 2)使用一个controller(本文使用RNN作为controller),controller首先读入一个query表示接收一个预测任务,然后读入memory中所有的信息进行编码(通过加权求和的方式综合每一个node的内容);
  • 3)对于每一个node所对应的存储单元,controller结合已编码的信息,该node对应存储单元的当前信息,以及node邻接nodes所对应存储单元的信息,来共同决定如何更新memory内容,并写入memory,作为一次read和write操作;
  • 4)迭代地进行多次read和write操作,然后做一个预测任务的输出(在这里,输出使用了一个二分类,表示active和inactive)
  • 5)训练的过程只需要反复执行2)-4)步,每次执行做一次参数的更新即可。
    图网络之——Graph Memory Networks

下面具体讲述每一个过程。

Graph Memory

比较容易,就是一个矩阵,矩阵的每一列(行)存储一个node的表达,之后每一次迭代都会修改(write)memory中的内容。可以理解为一个embedding的过程。

Controller

controller负责refine memory cell中的内容,本文使用LSTM作为controller。具体做法如下:

  • 1)在t0时刻读入一个query
    图网络之——Graph Memory Networks
  • 2)之后的时间步则执行相同的操作,假设在t时刻,为了回答query,controller需要读入所有memory cell的summation,以注意力的方式读入,即赋予每一个cell不同的权重,其中mt就是summation,mt可以看做整个graph的表达,ht包含graph和query的信息,可以进一步用来做输出。使用attention机制的好处是,可以有选择性的决定哪些cell对于预测结果更重要。

    图网络之——Graph Memory Networks

    权重的获取需要综合考虑h和m,原理和典型的attention相似,公式如下:

    图网络之——Graph Memory Networks

  • 3)读入之后还需要更新memory来refine每一个cell中的表达,这时候不仅需要考虑到controller的隐状态,还需要考虑被更新的cell以及它的邻接点,只有这样才能把整个graph综合考虑在内。计算方式如下:

    图网络之——Graph Memory Networks

    注: 公式里的r代表边的类型,因为在实际应用中graph中往往会存在多种类型的边,这里每一种边代表一种化学键,每一种边使用一套参数。

  • 4)反复执行2)-3) T次,即反复的执行read和write操作,每一个node的表达都得以被refine,这样就获得了所谓的graph structure。最后输出分类结果。

注意: 更新h和m的时候,采用了skip-connection的方式,类似于下面的方式:

图网络之——Graph Memory Networks

GraphMem for multi-task learning

为了综合多个dataset的数据,使网络达到更好的训练效果,采用multi-task的方式训练。网络本身并没有任何变化,每一个dataset对应一个task,仅仅用不同的query区分不同的task。