Inductive Representation Learning on Large Graphs
-
LINK: https://arxiv.org/abs/1706.02216
-
CLASSIFICATION: SPATIAL-BASED GCN
-
YEAR: Submitted on 7 Jun 2017 (v1), last revised 10 Sep 2018 (this version, v4)
-
FROM: NIPS 2017
-
WHAT PROBLEM TO SOLVE: Most existing approaches require that all nodes in the graph are present during training of the embeddings; these previous approaches are inherently transductive and do not naturally generalize to unseen nodes.
-
SOLUTION: Instead of training individual embeddings for each node, we learn a function that generates embeddings by sampling and aggregating features from a node’s local neighborhood.
-
CORE POINT:
-
VISUAL ILLUSTRATION
(Figure 1 of the paper: visual illustration of the GraphSAGE sample-and-aggregate approach.)
-
PRESENT WORK
- By incorporating node features in the learning algorithm, we simultaneously learn the topological structure of each node’s neighborhood as well as the distribution of node features in the neighborhood.
- Our algorithm can also be applied to graphs without node features.
- We train a set of aggregator functions that learn to aggregate feature information from a node’s local neighborhood. Each aggregator function aggregates information from a different number of hops, or search depth, away from a given node.
- We design an unsupervised loss function that allows GraphSAGE to be trained without task-specific supervision. We also show that GraphSAGE can be trained in a fully supervised manner.
-
GRAPHSAGE
-
Embedding generation (i.e., forward propagation) algorithm
Practically speaking, we found that our approach could achieve high performance with K = 2 and S1 · S2 ≤ 500.
The intuition behind Algorithm 1 is that at each iteration, or search depth, nodes aggregate information from their local neighbors, and as this process iterates, nodes incrementally gain more and more information from further reaches of the graph. Instead of using full neighborhood sets in Algorithm 1, we uniformly sample a fixed-size set of neighbors.
The main idea is to sample all the nodes needed for the computation first.
Note that the sampling process in Algorithm 2 is conceptually reversed compared to the iterations over k in Algorithm 1: we start with the “layer-K” nodes (i.e., the nodes in B) that we want to generate representations for; then we sample their neighbors (i.e., the nodes at “layer-K-1” of the algorithm) and so on.
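To make the sample-and-aggregate loop concrete, here is a minimal NumPy sketch of Algorithm 1 with uniform fixed-size neighbor sampling and a mean aggregator. The toy graph, helper names (sample_neighbors, embed), and dimensions are illustrative assumptions, not the authors' reference implementation; the minibatch variant (Algorithm 2) would additionally pre-sample all needed nodes so that intermediate representations are reused rather than recomputed.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0.0)

def sample_neighbors(adj, node, size):
    """Uniformly sample a fixed-size neighbor multiset (with replacement if
    the true neighborhood is smaller than the requested sample size)."""
    neigh = adj[node]
    idx = rng.choice(len(neigh), size=size, replace=len(neigh) < size)
    return [neigh[i] for i in idx]

def embed(node, k, adj, features, weights, sample_sizes):
    """Depth-k embedding of `node`: a recursive rendering of Algorithm 1."""
    if k == 0:
        return features[node]                                  # h_v^0 = x_v
    neigh = sample_neighbors(adj, node, sample_sizes[k - 1])
    h_neigh = np.mean([embed(u, k - 1, adj, features, weights, sample_sizes)
                       for u in neigh], axis=0)                # AGGREGATE over sampled N(v)
    h_self = embed(node, k - 1, adj, features, weights, sample_sizes)
    h = relu(weights[k - 1] @ np.concatenate([h_self, h_neigh]))  # W^k · CONCAT(h_v, h_N(v))
    return h / (np.linalg.norm(h) + 1e-12)                     # l2-normalize, as in Algorithm 1

# Toy example: 4 nodes on a cycle, 5-d input features, K = 2, hidden size 8.
adj = {0: [1, 3], 1: [0, 2], 2: [1, 3], 3: [2, 0]}
features = {v: rng.normal(size=5) for v in adj}
dims = [5, 8, 8]
weights = [rng.normal(size=(dims[k + 1], 2 * dims[k])) for k in range(2)]
z = embed(0, 2, adj, features, weights, sample_sizes=[25, 10])
print(z.shape)  # (8,)
```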
-
Learning the parameters of GraphSAGE
The graph-based unsupervised loss encourages nearby nodes to have similar representations while pushing the representations of disparate nodes apart:
J_G(z_u) = −log(σ(z_u⊤ z_v)) − Q · E_{v_n∼P_n(v)}[log(σ(−z_u⊤ z_{v_n}))]
where v is a node that co-occurs near u on a fixed-length random walk, σ is the sigmoid function, P_n is a negative sampling distribution, and Q defines the number of negative samples.
In cases where representations are to be used only on a specific downstream task, the unsupervised loss (Equation 1) can simply be replaced, or augmented, by a task-specific objective (e.g., cross-entropy loss).
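As a quick illustration, here is a hedged NumPy sketch of the unsupervised objective above, approximating the expectation over negatives with Q drawn samples; z_u and z_v are assumed to be the embeddings of a node and one of its random-walk co-occurring nodes, and z_neg a (Q, d) array of negative-sample embeddings. Variable names are illustrative.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def unsupervised_loss(z_u, z_v, z_neg):
    """-log σ(z_u·z_v) minus the summed log σ(-z_u·z_vn) over Q negative samples."""
    pos = -np.log(sigmoid(z_u @ z_v) + 1e-12)                # pull co-occurring nodes together
    neg = -np.sum(np.log(sigmoid(-(z_neg @ z_u)) + 1e-12))   # push negatives apart; Q = len(z_neg)
    return pos + neg
```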
-
Aggregator Architectures
-
Mean aggregator
The mean aggregator is nearly equivalent to the convolutional propagation rule used in the transductive GCN framework.
Concatenation can be viewed as a simple form of a “skip connection” between the different “search depths”, or “layers” of the GraphSAGE algorithm, and it leads to significant gains in performance.
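For contrast, a short NumPy sketch of the two update forms discussed here, under the assumption that hidden states are plain arrays: the GCN-like mean rule (which averages the node's own state together with its neighbors' and uses no concatenation) versus the concatenation-based update of Algorithm 1, whose CONCAT acts as the skip connection. Names and shapes are illustrative.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def mean_aggregator_gcn_style(h_self, h_neighbors, W):
    """Convolutional variant: average the node's own state with its sampled
    neighbors' states (no concatenation), close to the transductive GCN rule."""
    h_mean = np.mean(np.vstack([h_self[None, :], h_neighbors]), axis=0)
    return relu(W @ h_mean)

def aggregate_with_concat(h_self, h_neigh_agg, W):
    """Algorithm 1 form: concatenate the node's previous state with the
    aggregated neighborhood state -- the 'skip connection' between depths."""
    return relu(W @ np.concatenate([h_self, h_neigh_agg]))
```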
-
LSTM aggregator
Compared to the mean aggregator, LSTMs have the advantage of greater expressive capability. However, it is important to note that LSTMs are not inherently symmetric (i.e., they are not permutation invariant), since they process their inputs in a sequential manner; here they are adapted to operate on an unordered set by simply applying them to a random permutation of the node's neighbors.
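Below is a rough NumPy sketch of an LSTM-style aggregator (a single hand-rolled LSTM cell, not necessarily the paper's exact architecture), shown mainly to highlight the random permutation of the sampled neighbor states that adapts an order-sensitive model to an unordered set; all parameter names and shapes are assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_aggregate(h_neighbors, params, rng):
    """Run an LSTM cell over a *random permutation* of the sampled neighbor
    states and return the final hidden state as the aggregated neighborhood."""
    W, U, b = params                         # stacked gate parameters: input, forget, output, cell
    d = U.shape[1]
    h, c = np.zeros(d), np.zeros(d)
    for x in rng.permutation(h_neighbors):   # random neighbor order
        z = W @ x + U @ h + b
        i, f, o, g = np.split(z, 4)
        i, f, o, g = sigmoid(i), sigmoid(f), sigmoid(o), np.tanh(g)
        c = f * c + i * g
        h = o * np.tanh(c)
    return h

rng = np.random.default_rng(0)
d_in, d = 5, 8
params = (rng.normal(size=(4 * d, d_in)), rng.normal(size=(4 * d, d)), np.zeros(4 * d))
print(lstm_aggregate(rng.normal(size=(10, d_in)), params, rng).shape)  # (8,)
```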
-
Pooling aggregator
The final aggregator we examine is both symmetric and trainable.
Intuitively, the multi-layer perceptron can be thought of as a set of functions that compute features for each of the node representations in the neighbor set. By applying the max-pooling operator to each of the computed features, the model effectively captures different aspects of the neighborhood set.
In principle, any symmetric vector function could be used in place of the max operator (e.g., an element-wise mean). We found no significant difference between max- and mean-pooling in developmental tests.
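A small NumPy sketch of the (single-layer) max-pooling aggregator described above; W_pool and b_pool are illustrative parameter names, and the max could be swapped for an element-wise mean as noted.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def pooling_aggregate(h_neighbors, W_pool, b_pool):
    """Push each sampled neighbor state through a one-layer MLP, then take the
    element-wise max over the neighbor set (symmetric and trainable)."""
    transformed = relu(h_neighbors @ W_pool.T + b_pool)  # per-neighbor feature computation
    return transformed.max(axis=0)                       # element-wise max-pooling
```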
-
-
-
EXPERIMENTS
For all the GraphSAGE variants we used rectified linear units as the non-linearity and set K = 2 with neighborhood sample sizes S1 = 25 and S2 = 10.
-
3 benchmark tasks
- Classifying academic papers into different subjects using the Web of Science citation dataset.
- Classifying Reddit posts as belonging to different communities
- Classifying protein functions across various biological protein-protein interaction (PPI) graphs.
-
4 baselines
- A random classifier.
- A logistic regression feature-based classifier (that ignores graph structure).
- The DeepWalk algorithm as a representative factorization-based approach.
- A concatenation of the raw features and DeepWalk embeddings.
-
Citation data
For features, we used node degrees and processed the paper abstracts according to Arora et al.’s sentence embedding approach, with 300-dimensional word vectors trained using the GenSim word2vec implementation.
-
Reddit data
For features, we used off-the-shelf 300-dimensional GloVe CommonCrawl word vectors; for each post, we concatenated (i) the average embedding of the post title, (ii) the average embedding of all the post’s comments, (iii) the post’s score, and (iv) the number of comments made on the post.
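A hedged sketch of this feature construction, assuming a glove dict mapping tokens to 300-dimensional vectors and pre-tokenized title/comment text; function and variable names are illustrative, not the paper's preprocessing code.

```python
import numpy as np

def post_features(title_tokens, comment_tokens, score, n_comments, glove):
    """Concatenate: average title embedding, average comment embedding,
    the post's score, and its comment count."""
    def avg_embed(tokens):
        vecs = [glove[t] for t in tokens if t in glove]
        return np.mean(vecs, axis=0) if vecs else np.zeros(300)
    return np.concatenate([avg_embed(title_tokens),
                           avg_embed(comment_tokens),
                           [float(score), float(n_comments)]])
```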
-
Protein-protein interactions
-
Prediction results
-
Runtime and parameter sensitivity
For the GraphSAGE variants, we found that setting K = 2 provided a consistent boost in accuracy of around 10-15%, on average, compared to K = 1; however, increasing K beyond 2 gave marginal returns in performance (0-5%) while increasing the runtime by a prohibitively large factor of 10-100×, depending on the neighborhood sample size. We also found diminishing returns for sampling large neighborhoods (Figure 2.B).
-
A particularly interesting direction for future work
Exploring non-uniform neighborhood sampling functions, and perhaps even learning these functions as part of the GraphSAGE optimization.
-
-
-
EXISTING PROBLEMS: Incapable of distinguishing some different graph structures: aggregators such as the mean and max-pooling are not injective over neighbor multisets, so distinct neighborhoods can be mapped to the same representation.
-
IMPROVEMENT IDEAS: 404