
link_prediction_with_gnn


Abstract

Link prediction is the problem of predicting whether there is a link between two nodes in a graph. It can be used to predict missing links, or links that will appear in the future, in an evolving graph. Using the notation G = (V, E) for a graph with nodes V and edges E, and given two nodes v1 and v2, a link prediction algorithm tries to predict whether those two nodes will be connected, based on the node features and the graph structure. Lately, graph neural networks (GNNs) have often been used for node classification and link prediction problems. They are extremely useful in numerous interdisciplinary fields where it is important to incorporate domain-specific knowledge to capture more fine-grained relationships in the data. Such fields usually involve working with heterogeneous and large-scale graphs. GNNs iteratively update node representations by aggregating the representations of each node's neighbors together with the node's own representation from the previous iteration. These properties make graph neural networks a great tool for various problems we encounter at Memgraph. If your graph evolves over time, check out the TGN model that Memgraph engineers have already developed.
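To make the aggregation step concrete, here is a minimal sketch of one message-passing round in plain Python. It is a simplification, not the module's code: real GNN layers such as GraphSAGE or GAT also apply learned weight matrices and non-linearities, whereas the hypothetical gnn_layer below only averages each node's previous vector with those of its neighbours.

```python
from typing import Dict, List

Vector = List[float]


def gnn_layer(features: Dict[str, Vector],
              neighbors: Dict[str, List[str]]) -> Dict[str, Vector]:
    """One unweighted message-passing round: each node's new representation is
    the element-wise mean of its own previous vector and its neighbours' vectors."""
    updated: Dict[str, Vector] = {}
    for node, own in features.items():
        vectors = [own] + [features[n] for n in neighbors.get(node, [])]
        # zip(*vectors) groups the i-th coordinate of every vector together.
        updated[node] = [sum(coords) / len(vectors) for coords in zip(*vectors)]
    return updated


# Path graph v1 - v2 - v3 with two-dimensional input features.
features = {"v1": [1.0, 0.0], "v2": [0.0, 1.0], "v3": [1.0, 1.0]}
neighbors = {"v1": ["v2"], "v2": ["v1", "v3"], "v3": ["v2"]}

h1 = gnn_layer(features, neighbors)  # information from 1-hop neighbours
h2 = gnn_layer(h1, neighbors)        # v1 now also depends on v3 (2 hops away)
print(h1["v1"], h2["v1"])
```

After two rounds, v1's vector depends on v3 even though the two nodes are not adjacent, which is why the number of layers bounds how far information can travel.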

Blog Posts

The following blog posts explain how we tried to apply link prediction:

About the query module

This module supports the following features:

  • support for both homogeneous and heterogeneous graphs
  • support for disconnected graphs
  • use as a recommendation system
  • a semi-inductive link prediction setup, where a larger, updated graph is used for inference
  • an inductive link prediction setup, in which the training and inference graphs are different
  • transductive graph splitting (training and validation sets)
  • the graph attention layer, which aggregates information from the first-hop neighbourhood using an attention mechanism (introduced by Velickovic et al.)
  • the GraphSAGE layer, which extends the usability of graph neural networks to large-scale graphs (introduced by Hamilton et al.)
  • mlp and dot predictors for combining node scores into edge scores
  • ADAM and SGD optimizers for training the neural networks
  • batch training
  • parallel graph sampling using multiple threads
  • negative graph sampling, which produces a graph consisting only of edges that don't exist in the original graph
  • evaluation of the model's training performance with metrics such as AUC, precision, recall, accuracy and the confusion matrix
  • evaluation of the model's recommendation performance with Precision@k, Recall@k, F1@k and Average Precision
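The negative graph sampling mentioned above can be illustrated with a small sketch. The function sample_negative_edges is hypothetical: for every positive edge it keeps the source node and draws num_neg_per_pos_edge random destinations that the source is not already connected to. How the module (or DGL's negative samplers) actually picks negatives may differ.

```python
import random
from typing import List, Set, Tuple

Edge = Tuple[int, int]  # (source node id, destination node id)


def sample_negative_edges(pos_edges: List[Edge], num_nodes: int,
                          num_neg_per_pos_edge: int = 1,
                          seed: int = 0) -> List[Edge]:
    """For every positive edge (u, v), draw num_neg_per_pos_edge pairs (u, w)
    such that w != u and (u, w) is not an edge of the graph."""
    rng = random.Random(seed)
    # Here pos_edges is treated as the full edge set of the graph.
    existing: Set[Edge] = set(pos_edges)
    negatives: List[Edge] = []
    for src, _ in pos_edges:
        candidates = [w for w in range(num_nodes)
                      if w != src and (src, w) not in existing]
        if not candidates:  # src already points to every other node
            continue
        for _ in range(num_neg_per_pos_edge):
            negatives.append((src, rng.choice(candidates)))  # with replacement
    return negatives


positives = [(0, 1), (1, 2), (2, 0)]
print(sample_negative_edges(positives, num_nodes=4, num_neg_per_pos_edge=2))
```

In a typical setup, positive edges are labelled 1 and the sampled negative edges 0, so the model learns to score existing edges higher than non-existing ones.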

If you want to try out our implementation, head to github/memgraph/mage and find python/link_prediction.py. Feel free to give us a ⭐ if you like the code. The easiest way to test link prediction is to download Memgraph Platform and use one of the preloaded datasets in Memgraph Lab.

Keep the following in mind when using link prediction:

  • the features of all nodes must be stored under the same property name (e.g. saved as the 'features' property in Memgraph)
  • the model's performance on the validation set is obtained using the transductive splitting mode; the inductive dataset split is not yet supported. You can find more information about graph splitting in the slides of the Graph Machine Learning course offered by Stanford.
  • to improve performance, a self-loop with the edge type self is added to each node
  • the user can set a flag to automatically add a reverse edge for each existing edge and hence convert a directed graph into a bidirected one. If the source and destination nodes of the edge are of the same type, the reverse edge gets the same edge type as the original edge; otherwise, the prefix rev_ is added to the original edge type. See the FAQ section to learn why reverse edges need special care in ML training and how you can run into problems if your graph is already undirected 🤔
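The reverse-edge and self-loop rules from the last two points can be sketched in plain Python. Edges are modelled here as hypothetical (source type, edge type, destination type, source id, destination id) tuples rather than DGL graphs, and the helper names with_reverse_edges and with_self_loops are illustrative. Only the rules themselves come from the text above: reuse the edge type when both endpoints have the same type, otherwise prefix it with rev_, and give self-loops the type self.

```python
from typing import Dict, List, Tuple

# (source node type, edge type, destination node type, source id, destination id)
Edge = Tuple[str, str, str, int, int]


def with_reverse_edges(edges: List[Edge]) -> List[Edge]:
    reverse = []
    for src_type, etype, dst_type, src, dst in edges:
        # Same endpoint types: reuse the edge type; otherwise prefix it with "rev_".
        rev_type = etype if src_type == dst_type else "rev_" + etype
        reverse.append((dst_type, rev_type, src_type, dst, src))
    return edges + reverse


def with_self_loops(nodes: Dict[str, List[int]], edges: List[Edge]) -> List[Edge]:
    # One edge of type "self" from every node back to itself.
    loops = [(ntype, "self", ntype, n, n) for ntype, ids in nodes.items() for n in ids]
    return edges + loops


nodes = {"Paper": [0, 1], "Author": [0]}
edges = [("Paper", "CITES", "Paper", 0, 1), ("Author", "WROTE", "Paper", 0, 0)]
graph = with_self_loops(nodes, with_reverse_edges(edges))
for edge in graph:
    print(edge)
```

Note that the reverse of (Paper 0)-[CITES]->(Paper 1) is indistinguishable from a genuine citation from Paper 1 to Paper 0, which is one reason reverse edges need care during training.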

Feel free to open a GitHub issue or start a discussion on Discord if you want to speed up development.

Usage

The link prediction module is meant to be used in the following order:

  • set the parameters by calling the set_model_parameters procedure
  • train a model by calling train
  • (optionally) inspect the training results by calling get_training_results
  • predict whether two vertices are connected by calling predict, or
  • find the most likely relationships for a vertex by calling recommend

Implementation details

For the underlying GNN training we use the DGL library.

Fast and memory-efficient message passing primitives for training Graph Neural Networks. Scale to giant graphs via multi-GPU acceleration and distributed training infrastructure.

-- DGL team

Splitting the dataset

If the user sets split_ratio to 1.0, the model trains on the whole dataset without validating its performance on a validation set. If split_ratio is set to a value between 0.0 and 1.0 but the graph is too small for such a split, an exception is thrown.
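A minimal sketch of the splitting behaviour described above, assuming edges are shuffled uniformly at random before the cut. split_edges and its seed argument are hypothetical, and the module's actual splitting code and error message may differ.

```python
import random
from typing import List, Tuple

Edge = Tuple[int, int]


def split_edges(edges: List[Edge], split_ratio: float = 0.8,
                seed: int = 0) -> Tuple[List[Edge], List[Edge]]:
    """Split edges into (training, validation) lists according to split_ratio."""
    if not 0.0 < split_ratio <= 1.0:
        raise ValueError("split_ratio must be in (0.0, 1.0]")
    if split_ratio == 1.0:
        return list(edges), []  # train on the whole graph, no validation set
    shuffled = list(edges)
    random.Random(seed).shuffle(shuffled)
    num_train = int(len(shuffled) * split_ratio)
    if num_train == 0 or num_train == len(shuffled):
        # One of the two sets would be empty: the graph is too small.
        raise ValueError("graph is too small for the requested split_ratio")
    return shuffled[:num_train], shuffled[num_train:]


edges = [(i, i + 1) for i in range(10)]
train, validation = split_edges(edges, split_ratio=0.8)
print(len(train), len(validation))
```

With a single edge, split_edges([(0, 1)], 0.8) raises, because 80% of one edge rounds down to an empty training set.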

Self-loops

If specified by the user, a self-loop edge is added to every node to improve link prediction performance. Self-loop edges are added only with the edge type self, and a custom module was added to enable this.

Batch training

In heterogeneous graphs, edges of all types are used to build a node's neighbourhood, but training is done on only one edge type, which the user sets.

For each gradient descent step, we select a mini-batch of nodes whose final representations at the L-th layer are to be computed. We then take all or some of their neighbours at the L−1 layer. This process continues until we reach the input. This iterative process builds the dependency graph starting from the output and working backwards to the input, as the figure below shows:

-- DGL docs

The reader is encouraged to take a look at the DGL mini-batch explanation for more details.
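The backwards construction described in the quote can be sketched without DGL (which provides samplers such as dgl.dataloading.MultiLayerNeighborSampler for real training). The hypothetical dependency_layers function starts from the seed nodes of a mini-batch and, layer by layer, adds up to fanout sampled neighbours of every node that is already needed.

```python
import random
from typing import Dict, List, Set


def dependency_layers(adj: Dict[int, List[int]], seeds: List[int],
                      fanouts: List[int], seed: int = 0) -> List[Set[int]]:
    """Walk backwards from the output layer to the input layer.

    fanouts[i] is the maximum number of neighbours sampled per node for GNN
    layer i + 1 (listed from the input side to the output side). The result
    starts with the seed nodes (layer L) and ends with the input nodes (layer 0).
    """
    rng = random.Random(seed)
    frontier: Set[int] = set(seeds)
    layers = [set(frontier)]
    for fanout in reversed(fanouts):
        # A node's own previous representation is needed as well.
        needed = set(frontier)
        for node in sorted(frontier):
            nbrs = adj.get(node, [])
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            needed.update(picked)
        frontier = needed
        layers.append(set(frontier))
    return layers


path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
print(dependency_layers(path, seeds=[0], fanouts=[10, 10]))
```

With two layers and a large fanout, computing node 0's output needs nodes {0, 1} at the middle layer and {0, 1, 2} at the input; a small fanout caps how fast this set grows on high-degree nodes.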

Procedures

The link prediction module is stateful: the user can run several procedures one after another without losing context. Start by setting the parameters that will be used in training. If the graph is heterogeneous (has more than one edge type), the target_relation parameter must be set so that the model can distinguish supervision edges (edges used in prediction) from message passing edges (edges used for message aggregation). For a homogeneous graph, the target relation is inferred automatically. The node_features_property parameter specifies the property in which the original node features are stored; graph neural networks need these features to compute node embeddings. All other parameters are optional.
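The rule for choosing the target relation can be sketched as follows. This is an illustration under assumptions: edge types are represented as DGL-style canonical (source node type, edge type, destination node type) tuples, and resolve_target_relation is a hypothetical helper, not the module's function. It also covers the two accepted forms of target_relation described under set_model_parameters().

```python
from typing import List, Optional, Tuple, Union

# (source node type, edge type, destination node type)
CanonicalEtype = Tuple[str, str, str]


def resolve_target_relation(
        etypes: List[CanonicalEtype],
        target_relation: Optional[Union[str, CanonicalEtype]] = None) -> CanonicalEtype:
    if target_relation is None:
        if len(etypes) == 1:
            return etypes[0]  # homogeneous graph: the only edge type is the target
        raise ValueError("target_relation must be set for heterogeneous graphs")
    if isinstance(target_relation, tuple):
        if target_relation in etypes:
            return target_relation
        raise ValueError(f"unknown relation {target_relation}")
    # A bare edge type is accepted only if it names exactly one canonical type.
    matches = [c for c in etypes if c[1] == target_relation]
    if len(matches) == 1:
        return matches[0]
    raise ValueError(f"{target_relation!r} is unknown or ambiguous; "
                     "pass a (source, edge, destination) tuple instead")


etypes = [("Customer", "SUBSCRIBES_TO", "Plan"),
          ("Customer", "LIKES", "Plan"),
          ("Employee", "LIKES", "Plan")]
print(resolve_target_relation(etypes, "SUBSCRIBES_TO"))
print(resolve_target_relation(etypes, ("Employee", "LIKES", "Plan")))
```

Here "LIKES" alone would be rejected, because it connects two different source-destination node type combinations.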

set_model_parameters()

The following parameters of the link prediction module can be set by calling the set_model_parameters procedure:

Input:

| Name | Type | Default | Description |
|---|---|---|---|
| hidden_features_size | mgp.List[int] | [16, 16] | Size of each hidden layer in the architecture. The input feature size is determined automatically when the original graph is converted to a DGL-compatible one. |
| layer_type | str | graph_attn | Supported values are graph_sage and graph_attn. |
| num_epochs | int | 100 | Number of epochs for model training. |
| optimizer | str | ADAM | Supported values are ADAM and SGD. |
| learning_rate | float | 0.01 | Optimizer's learning rate. |
| split_ratio | float | 0.8 | Split ratio between the training and the validation set. There is no test set, because it is assumed that the user first needs to create new edges in the original dataset to test the model on them. |
| node_features_property | str | features | Name of the property in which the node features are saved. |
| device_type | str | cpu | Whether the model is trained on the CPU or a CUDA GPU. To train on a CUDA GPU, check that the system supports it with torch.cuda.is_available(), then set this flag to cuda. |
| console_log_freq | int | 5 | How often (in epochs) results are printed. This also determines which results are returned as training and validation results when calling the train procedure. |
| checkpoint_freq | int | 5 | How often (in epochs) the model is saved. The model is persisted on disk. |
| aggregator | str | mean | Aggregator used in the GraphSAGE model. Supported values are lstm, pool, mean and gcn. |
| metrics | mgp.List[str] | [loss, accuracy, auc_score, precision, recall, f1, true_positives, true_negatives, false_positives, false_negatives] | Metrics used to evaluate the model on the validation set. Epoch information is always displayed as well. |
| predictor_type | str | dot | Type of the predictor, which combines node scores into edge scores. Supported values are dot and mlp. |
| attn_num_heads | List[int] | [4, 1] | Number of attention heads in each GAT layer. More than one head can be used in every layer except the last one. The list must have the same length as the number of layers specified by hidden_features_size. |
| tr_acc_patience | int | 8 | Number of epochs for which a drop in validation accuracy is tolerated before training is stopped. |
| context_save_dir | str | None | Path where the model and predictor are saved every checkpoint_freq epochs. |
| target_relation | str | None | Unique edge type used for training. Either an edge type alone, or a tuple (source node type, edge type, destination node type) if the same edge type connects several source-destination node type combinations. |
| num_neg_per_pos_edge | int | 1 | Number of negative edges sampled per positive edge in mini-batch training. |
| batch_size | int | 256 | Batch size used in both training and validation. It specifies the number of indices in each batch. |
| sampling_workers | int | 5 | Number of workers that cooperate in sampling during training and validation. |
| last_activation_function | str | sigmoid | Activation function applied after the last layer of the model and before the predictor. Currently, only sigmoid is supported. |
| add_reverse_edges | bool | False | Whether the module should add a reverse edge for each existing edge in the obtained graph. If the source and destination nodes are of the same type, reverse edges get the same edge type; otherwise, the prefix rev_ is added to the original edge type. Reverse edges are excluded from the message passing edges of their corresponding supervision edges. |
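The two predictor_type options can be illustrated with toy vectors. In this sketch, dot_predictor takes the inner product of two node embeddings, and mlp_predictor follows the Linear, ReLU, Linear pattern on concatenated embeddings used in DGL's link prediction tutorial; the weights are made up rather than learned. Applying a sigmoid to the final edge score is an assumption made here so that the result lies between 0 and 1, like the score returned by predict; per the table, the module's last_activation_function is applied after the last GNN layer, before the predictor.

```python
import math
from typing import List

Vector = List[float]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def dot_predictor(h_u: Vector, h_v: Vector) -> float:
    # Parameter-free and symmetric: score(u, v) == score(v, u).
    return sum(a * b for a, b in zip(h_u, h_v))


def mlp_predictor(h_u: Vector, h_v: Vector, w1: List[Vector], b1: Vector,
                  w2: Vector, b2: float) -> float:
    # Concatenate the two embeddings, then apply Linear -> ReLU -> Linear.
    x = h_u + h_v
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
              for row, b in zip(w1, b1)]
    return sum(w * h for w, h in zip(w2, hidden)) + b2


h_u, h_v = [1.0, 0.5], [0.5, 1.0]
w1 = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]  # toy weights, normally learned
b1, w2, b2 = [0.0, 0.0], [1.0, -1.0], 0.0
print(sigmoid(dot_predictor(h_u, h_v)))
print(sigmoid(mlp_predictor(h_u, h_v, w1, b1, w2, b2)))
```

The dot predictor has no parameters and always scores (u, v) and (v, u) the same, while the MLP predictor is learnable and can score the two directions differently, which is visible when its arguments are swapped.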

Output:

  • status: bool -> True if all parameters were successfully updated, False otherwise.
  • message: str -> OK if all parameters were successfully updated, an error message otherwise.

Only the parameters you want to change from their default values need to be passed; the example below also passes several default values explicitly:

CALL link_prediction.set_model_parameters({num_epochs: 100, node_features_property: "features", tr_acc_patience: 8, target_relation: "CITES", batch_size: 256, last_activation_function: "sigmoid", add_reverse_edges: True})
YIELD status, message
RETURN status, message;

train()

The train method doesn't take any parameters, so it is very simple to use.

Output:

  • training_results: List[Dict[str, float]] -> List of training results through epochs. The model's performance is evaluated every console_log_freq epochs.
  • validation_results: List[Dict[str, float]] -> List of validation results through epochs. The model's performance is evaluated every console_log_freq epochs.

You can just call

CALL link_prediction.train()
YIELD training_results, validation_results
RETURN training_results, validation_results;

to get training and validation results summarized through epochs.
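The interplay of num_epochs, console_log_freq and tr_acc_patience can be sketched as a plain training loop. The exact early-stopping rule is an assumption: this sketch counts consecutive epochs in which validation accuracy is lower than in the previous epoch and stops once that count reaches tr_acc_patience. The validate callback and toy_accuracy curve are hypothetical stand-ins for a real epoch of training and validation.

```python
from typing import Callable, Dict, List, Tuple


def train_loop(validate: Callable[[int], float], num_epochs: int = 100,
               console_log_freq: int = 5,
               tr_acc_patience: int = 8) -> Tuple[List[Dict[str, float]], int]:
    results: List[Dict[str, float]] = []
    prev_acc = None
    drops = 0
    epoch = 0
    for epoch in range(1, num_epochs + 1):
        # A real loop would run one epoch of mini-batch training here.
        acc = validate(epoch)
        drops = drops + 1 if prev_acc is not None and acc < prev_acc else 0
        prev_acc = acc
        if epoch % console_log_freq == 0:
            results.append({"epoch": epoch, "accuracy": acc})
        if drops >= tr_acc_patience:
            break  # too many consecutive drops in validation accuracy
    return results, epoch


def toy_accuracy(epoch: int) -> float:
    # Improves until epoch 10, then degrades by 0.01 every epoch.
    return epoch / 10 if epoch <= 10 else 1.0 - (epoch - 10) * 0.01


results, last_epoch = train_loop(toy_accuracy)
print(last_epoch, results)
```

With this toy curve, accuracy drops for the eighth consecutive time at epoch 18, so training stops there, and only epochs 5, 10 and 15 appear in the returned results.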

get_training_results()

The get_training_results procedure returns the performance data obtained from the last training, in the same form as the results of the train procedure. If no model is loaded, an exception is thrown.

CALL link_prediction.get_training_results()
YIELD training_results, validation_results
RETURN training_results, validation_results;

Output:

  • training_results: List[Dict[str, float]] -> List of training results through epochs. The model's performance is evaluated every console_log_freq epochs.
  • validation_results: List[Dict[str, float]] -> List of validation results through epochs. The model's performance is evaluated every console_log_freq epochs.

predict()

The predict procedure takes two arguments, src_vertex and dest_vertex, and predicts whether there is an edge between them. It supports both an “actual” prediction scenario, in which the edge doesn't exist and the user wants to predict whether it should, and a scenario in which the edge already exists and the user wants to check how the model evaluates it.

Input

  • src_vertex: mgp.Vertex -> Source vertex of the edge.
  • dest_vertex: mgp.Vertex -> Destination vertex of the edge.

Output

  • score: mgp.Number -> Score between 0 and 1 that represents the probability of the two nodes being connected.

MATCH (v1:PAPER {id: "ID_1"})
MATCH (v2:PAPER {id: "ID_2"})
CALL link_prediction.predict(v1, v2)
YIELD score
RETURN score;

recommend()

The recommend procedure recommends the best k nodes from dest_vertices for src_vertex. It is implemented efficiently using a max heap, and the best nodes are determined based on the edge scores. Metrics specific to recommendation systems (Precision@k, Recall@k, F1@k and Average Precision) are logged to the standard output. The number of returned recommendations is min(k, length(dest_vertices), length(results)), where results is the list of all recommendations given by the model (i.e. destination nodes classified as positive examples).
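The top-k selection described above can be sketched with Python's heapq. Since heapq only provides a min-heap, the sketch negates scores to get max-heap behaviour. The positive-classification threshold of 0.5 is an assumption, as is the representation of candidates as a dictionary from node name to edge score; recommend_top_k is a hypothetical name.

```python
import heapq
from typing import Dict, List, Tuple


def recommend_top_k(scores: Dict[str, float], k: int,
                    threshold: float = 0.5) -> List[Tuple[float, str]]:
    """Return up to k (score, node) pairs, best first.

    scores maps each candidate destination node to the model's edge score.
    Only candidates classified as positive (score > threshold) are kept.
    """
    # heapq is a min-heap, so negating the scores turns it into a max-heap.
    heap = [(-score, node) for node, score in scores.items() if score > threshold]
    heapq.heapify(heap)
    k = min(k, len(scores), len(heap))
    best = []
    for _ in range(k):
        neg_score, node = heapq.heappop(heap)
        best.append((-neg_score, node))
    return best


scores = {"planA": 0.91, "planB": 0.42, "planC": 0.77, "planD": 0.66}
print(recommend_top_k(scores, k=5))
```

Heapifying n candidates takes O(n) time and each of the k pops O(log n), so the cost is O(n + k log n) rather than the O(n log n) of a full sort. Here k=5 shrinks to 3, because only three plans are classified as positive.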

Input

  • src_vertex: mgp.Vertex -> Source node.
  • dest_vertices: List[mgp.Vertex] -> Destination nodes. If they are not all of the same type, an exception is thrown.
  • k: int -> Number of edges to recommend.

Output

  • score: mgp.Number -> Score between 0 and 1 that represents the probability of the two nodes being connected.
  • recommendation: mgp.Vertex -> A reference to the target node.

MATCH (v1:Customer {id: "8779-QRDMV"})
MATCH (p:Plan)
WITH collect(p) AS all_plans, v1
CALL link_prediction.recommend(v1, all_plans, 5)
YIELD score, recommendation
RETURN v1, score, recommendation;
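The recommendation metrics logged by recommend can be computed as below for one source node. These are common textbook definitions, stated here as assumptions; the module may, for instance, normalise average precision differently. In this sketch, recommended is the ranked list of recommended nodes and relevant is the set of nodes the source node is actually connected to.

```python
from typing import List, Set


def precision_at_k(recommended: List[str], relevant: Set[str], k: int) -> float:
    hits = sum(1 for node in recommended[:k] if node in relevant)
    return hits / k if k > 0 else 0.0


def recall_at_k(recommended: List[str], relevant: Set[str], k: int) -> float:
    hits = sum(1 for node in recommended[:k] if node in relevant)
    return hits / len(relevant) if relevant else 0.0


def f1_at_k(recommended: List[str], relevant: Set[str], k: int) -> float:
    p = precision_at_k(recommended, relevant, k)
    r = recall_at_k(recommended, relevant, k)
    return 2 * p * r / (p + r) if p + r > 0 else 0.0


def average_precision(recommended: List[str], relevant: Set[str]) -> float:
    """Sum of precision@i over the ranks i that hold a relevant node,
    divided by the total number of relevant nodes."""
    hits, total = 0, 0.0
    for i, node in enumerate(recommended, start=1):
        if node in relevant:
            hits += 1
            total += hits / i
    return total / len(relevant) if relevant else 0.0


recommended = ["A", "C", "D", "B"]
relevant = {"A", "D", "E"}
print(precision_at_k(recommended, relevant, 3), recall_at_k(recommended, relevant, 3))
print(f1_at_k(recommended, relevant, 3), average_precision(recommended, relevant))
```

In this example, two of the top three recommendations are relevant (precision@3 = recall@3 = 2/3), and the relevant node E that was never recommended pulls average precision down to 5/9.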

load_context()

Loading the context means loading the model and the predictor. If the user specifies a path, the procedure tries to load the context from there. Otherwise, the context is loaded from the default location specified in the link_prediction_parameters module.

Input

  • path: str -> Path to the folder where the model and the predictor are saved.

Output

  • status: mgp.Any -> True to indicate that execution went well.

CALL link_prediction.load_context() YIELD * RETURN *;

reset_parameters()

You can explicitly reset the parameters at any time. Note, however, that for implementation reasons the parameters are reset before training even if you don't do so explicitly.

Output

  • status: mgp.Any -> True to indicate that the procedure finished successfully.

CALL link_prediction.reset_parameters() YIELD * RETURN *;

Results

We extensively tested our model on the Cora dataset and the Telecom recommendation dataset. To show how training performance can progress through epochs, here are the results of one of our basic models on the Cora dataset:

| epoch_num | AUC | accuracy | precision | recall | f1 |
|---|---|---|---|---|---|
| 1 | 0.64 | 0.594 | 0.613 | 0.494 | 0.547 |
| 2 | 0.781 | 0.696 | 0.711 | 0.663 | 0.686 |
| 3 | 0.798 | 0.729 | 0.752 | 0.682 | 0.715 |
| 4 | 0.754 | 0.686 | 0.716 | 0.617 | 0.663 |
| 5 | 0.789 | 0.711 | 0.715 | 0.7 | 0.707 |
| 6 | 0.813 | 0.756 | 0.742 | 0.784 | 0.763 |
| 7 | 0.884 | 0.772 | 0.764 | 0.791 | 0.775 |
| 8 | 0.859 | 0.775 | 0.781 | 0.766 | 0.773 |
| 9 | 0.871 | 0.805 | 0.822 | 0.777 | 0.798 |
| 10 | 0.832 | 0.759 | 0.776 | 0.729 | 0.752 |
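The accuracy, precision, recall and f1 columns are standard binary-classification metrics derived from the confusion matrix (true/false positives and negatives). As a quick sanity check, f1 is the harmonic mean of precision and recall: the epoch 1 row gives 0.547 and the epoch 10 row 0.752. The helper names below are hypothetical, the confusion-matrix counts are made up for illustration, and AUC is omitted because it needs the raw scores rather than counts.

```python
from typing import Dict


def f1_score(precision: float, recall: float) -> float:
    # Harmonic mean of precision and recall.
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0


def classification_metrics(tp: int, tn: int, fp: int, fn: int) -> Dict[str, float]:
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return {
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
        "precision": precision,
        "recall": recall,
        "f1": f1_score(precision, recall),
    }


print(round(f1_score(0.613, 0.494), 3))  # epoch 1 row of the table
print(classification_metrics(tp=40, tn=35, fp=15, fn=10))  # hypothetical counts
```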

Example

FAQ

Why can reverse edges cause problems?

Reverse edges in your dataset can cause problems if they are not excluded from the message passing edges when the model predicts their opposite edges (supervision edges). The best approach is to keep your graph directed and set add_reverse_edges in the set_model_parameters procedure; the module then adds the reverse edges automatically, in a way that doesn't leak information from the supervision edges.
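To see why the exclusion matters, consider predicting the supervision edge (0, 1) in a graph that also contains its reverse (1, 0): if (1, 0) stays in the message passing graph, the model can simply observe the answer. The hypothetical helper below removes each supervision edge and its reverse from the message passing edges. It is a homogeneous, plain-Python illustration, not the module's implementation (DGL's edge-prediction samplers offer comparable exclusion options).

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def message_passing_edges(all_edges: List[Edge],
                          supervision_edges: List[Edge]) -> List[Edge]:
    """Drop the supervision edges and their reverses from the message passing
    graph, so an edge can't be used to 'predict' itself."""
    excluded = set(supervision_edges)
    excluded |= {(dst, src) for src, dst in supervision_edges}
    return [edge for edge in all_edges if edge not in excluded]


# An undirected graph stored as two directed edges per connection.
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]
print(message_passing_edges(edges, supervision_edges=[(0, 1)]))
```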

What is a transductive dataset split?

The transductive dataset split assumes that the entire graph can be observed in all dataset splits. We distinguish four types of edges: training, validation, message passing and supervision edges.

The transductive dataset split is described in detail by Prof. Jure Leskovec in one of his lectures for the Graph ML course.
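The four edge types can be sketched as follows, in the spirit of the lecture's scheme: training edges are split into message passing and supervision edges, and at validation time all training edges pass messages while the validation edges are the ones to predict. Every node stays visible in every split. transductive_split and supervision_ratio are hypothetical names (supervision_ratio is not a parameter of the module), and the random shuffling is an assumption.

```python
import random
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


def transductive_split(edges: List[Edge], split_ratio: float = 0.8,
                       supervision_ratio: float = 0.25,
                       seed: int = 0) -> Dict[str, List[Edge]]:
    shuffled = list(edges)
    random.Random(seed).shuffle(shuffled)
    num_train = int(len(shuffled) * split_ratio)
    train, validation = shuffled[:num_train], shuffled[num_train:]
    num_sup = int(len(train) * supervision_ratio)
    return {
        # Training: predict the supervision edges using the message passing edges.
        "train_message_passing": train[num_sup:],
        "train_supervision": train[:num_sup],
        # Validation: all training edges pass messages; validation edges are predicted.
        "val_message_passing": train,
        "val_supervision": validation,
    }


edges = [(i, i + 1) for i in range(10)]
for name, part in transductive_split(edges).items():
    print(name, len(part))
```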