GMNN: Graph Markov Neural Networks
摘要
本文研究了关系数据中的半监督对象分类,这是关系数据建模中的一个基本问题。统计关系学习(例如关系马尔可夫网络)和图神经网络(例如图卷积网络)的文献都对该问题进行了广泛的研究。在本文中,我们提出了结合两个方法优点的图马尔可夫神经网络(GMNN)。 GMNN使用条件随机场对对象标签的联合分布进行建模,可以使用变分EM算法对其进行有效训练。在E步中,一个图神经网络学习有效的对象表示形式,以近似对象标签的后验分布。在M步中,另一个图神经网络用于对局部标签依赖性进行建模。
算法网络代码结构
本人直接阅读的半监督部分的代码。
如图所示:
- 通过输入的特征inputs_q,标签target_q,邻接关系矩阵adj,标签取idx_train所拥有的部分,训练一个两层的GNN网络trainer_q。
- inputs_q通过trainer_q.predict得到preds,这是一个分布,每个维度是一个值,值越大的维度的索引越可能被选中,再编码成只有0/1的新编码,inputs_p和target_p。
- 将得到的新编码作为标签和特征作为网络trainer_p(两层GMNN网络)的输入,训练网络,更新trainer_p。
- inputs_p通过trainer_p.predict得到新preds和target_q,target_q将idx_train所拥有的部分替换为真实标签,得到新target_q。
- 将最初的inputs_q和新target_q作为特征和标签,再训练trainer_q,trainer_q得到了更新。