link_prediction with GNN

This algorithm is not available in MAGE 1.14 and newer versions.

Link prediction can be defined as a problem where one wants to predict if there is a link between two nodes in the graph. It can be used for predicting missing or future links in the evolving graph. Using the notation G = (V, E) for a graph with nodes V and edges E and given two nodes v1 and v2, the link prediction algorithm tries to predict whether those two nodes will be connected, based on the node features and graph structure.

Graph neural networks (GNNs) are now widely used for node-classification and link-prediction problems. They are especially useful in interdisciplinary fields where domain-specific knowledge must be incorporated to capture more fine-grained relationships among the data. Such fields usually involve heterogeneous and large-scale graphs. GNNs iteratively update each node's representation by aggregating the representations of its neighbors together with its own representation from the previous iteration. These properties make graph neural networks a good fit for many of the problems we encounter at Memgraph. If your graph evolves over time, check out the TGN model that Memgraph engineers have already developed.

The module includes the following features:

  • Support for homogeneous and heterogeneous graphs.
  • Support for disconnected graphs.
  • Applicability as a recommendation system.
  • A semi-inductive link prediction setup where a larger, updated graph is used for the inference.
  • An inductive link prediction setup in which training and inference graphs are different.
  • A transductive graph splitting (training and validation sets).
  • Graph attention layer aggregates information using an attention mechanism from the first-hop neighbourhood as introduced by Velickovic et al.
  • GraphSAGE layer extends the usability of graph neural networks to large-scale graphs as introduced by Hamilton et al.
  • mlp and dot predictors are used for combining node scores to edge scores.
  • ADAM and SGD optimizers are used for training neural networks.
  • Support for batch training.
  • Parallel graph sampling is done using multiple threads.
  • Negative graph sampling, in which the sampled graph consists only of edges that don't exist in the original graph.
  • Evaluating the model's training performance using a variety of metrics like AUC, Precision, Recall, Accuracy and Confusion matrix.
  • Evaluating the model's recommendation performance with Precision@k, Recall@k, F1@k and Average Precision metrics.
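
To illustrate the negative sampling idea from the list above, here is a minimal Python sketch. The function name `sample_negative_edges` and the uniform rejection-sampling strategy are assumptions made for illustration; the module itself relies on DGL and may sample differently.

```python
import random


def sample_negative_edges(nodes, positive_edges, num_neg_per_pos=1, seed=0):
    """Sample (src, dst) pairs that do NOT exist among positive_edges.

    Hypothetical helper: uniform rejection sampling over ordered node pairs.
    Assumes positive_edges contains no self-loops and no duplicates.
    """
    rng = random.Random(seed)
    existing = set(positive_edges)
    wanted = num_neg_per_pos * len(positive_edges)

    # Number of ordered pairs (u, v) with u != v that are not existing edges.
    available = len(nodes) * (len(nodes) - 1) - len(existing)
    if wanted > available:
        raise ValueError("not enough non-existing edges to sample from")

    negatives = set()
    while len(negatives) < wanted:
        u, v = rng.choice(nodes), rng.choice(nodes)
        if u != v and (u, v) not in existing:
            negatives.add((u, v))
    return sorted(negatives)


# Small directed cycle: 10 -> 11 -> 12 -> 13 -> 10
nodes = [10, 11, 12, 13]
positive_edges = [(10, 11), (11, 12), (12, 13), (13, 10)]
negative_edges = sample_negative_edges(nodes, positive_edges, num_neg_per_pos=1)
```

Rejection sampling works well when the graph is sparse, because most random pairs are not connected; on dense graphs a different strategy would be needed.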

When using link prediction, keep the following in mind:

  • Node features must be stored under the same property name on all nodes (e.g., as the 'features' property).
  • The model's performance on the validation set is obtained using the transductive splitting mode; an inductive dataset split is not yet supported. You can find more information about graph splitting in the slides of the Graph Machine Learning course offered by Stanford.
  • To improve performance, a self-loop can be added to each node, with the edge type set to self.
  • You can set a flag to automatically add a reverse edge for each existing edge and hence convert a directed graph to a bidirected one. If the source and destination nodes of the edge are of the same type, the reverse edge type will be the same as the original edge type. Otherwise, the prefix rev_ will be added to the original edge type. See the FAQ section to learn why self-loops and reverse edges are important in ML training and what problems can arise if your graph is already undirected.

If you don’t have a feature vector to use as input for the GNN, you can call one or more of the centrality algorithms (PageRank, betweenness centrality, etc.) and save all the values into one vector, which then becomes the input for the ML algorithm. That vector usually consists of numbers that describe the dataset in a way that is useful for a particular use case. If the vector is well-defined for the use case, you can expect better results; if it is irrelevant to the use case, the results will probably not be very useful.
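
The idea of building a feature vector from centrality values can be sketched in plain Python. In Memgraph you would typically call the MAGE centrality procedures instead; the standalone functions `pagerank` and `build_features` below are hypothetical helpers, and the choice of PageRank plus in- and out-degree as features is only an assumption for illustration.

```python
def pagerank(nodes, edges, damping=0.85, iterations=50):
    """Plain power-iteration PageRank on a directed graph."""
    n = len(nodes)
    out_neighbors = {node: [] for node in nodes}
    for src, dst in edges:
        out_neighbors[src].append(dst)

    rank = {node: 1.0 / n for node in nodes}
    for _ in range(iterations):
        new_rank = {node: (1.0 - damping) / n for node in nodes}
        for node in nodes:
            # Dangling nodes (no outgoing edges) spread their rank evenly.
            targets = out_neighbors[node] or nodes
            share = damping * rank[node] / len(targets)
            for target in targets:
                new_rank[target] += share
        rank = new_rank
    return rank


def build_features(nodes, edges):
    """Combine several per-node scores into one feature vector per node."""
    in_degree = {node: 0 for node in nodes}
    out_degree = {node: 0 for node in nodes}
    for src, dst in edges:
        out_degree[src] += 1
        in_degree[dst] += 1
    rank = pagerank(nodes, edges)
    return {
        node: [rank[node], float(in_degree[node]), float(out_degree[node])]
        for node in nodes
    }


# Directed 4-cycle: every node ends up with the same PageRank value.
nodes = [10, 11, 12, 13]
edges = [(10, 11), (11, 12), (12, 13), (13, 10)]
features = build_features(nodes, edges)
```

Each resulting list could then be saved as the node property that the GNN reads its input features from.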

For the underlying GNN training, Memgraph uses the DGL library.


The link prediction module is used by calling the procedures in the following order:

  1. Set parameters by calling the set_model_parameters procedure.
  2. Train a model by calling the train procedure.
  3. Inspect training results (optional) by calling the get_training_results procedure.
  4. Predict the relationship between two vertices by calling the predict procedure.
  5. Find the most likely relationships by calling the recommend procedure.


Start by setting the parameters that will be used in training. If the graph is heterogeneous (more than one edge type), all edges are used for creating the node’s neighbourhood, but the model is trained on only one edge type, which is set by the target_relation parameter. This lets the model distinguish supervision edges (edges used in prediction) from message passing edges (edges used for message aggregation).

For each gradient descent step, we select a mini-batch of nodes whose final representations at the L-th layer are to be computed. We then take all or some of their neighbours at the L−1 layer. This process continues until we reach the input. This iterative process builds the dependency graph starting from the output and working backwards to the input, as the figure below shows:

-- DGL docs

Take a look at the DGL mini-batch explanation for more details.
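
The backward construction of the dependency graph described in the quote can be sketched without any library. The function `build_dependency_layers` and its `fanout` parameter are hypothetical names used only for illustration; DGL's own samplers handle this in practice.

```python
import random


def build_dependency_layers(seed_nodes, in_neighbors, num_layers, fanout=None, seed=0):
    """Return the node sets needed per layer, ordered from input to output.

    in_neighbors maps a node to the list of nodes that send messages to it.
    If fanout is given, at most that many neighbours are sampled per node.
    """
    rng = random.Random(seed)
    layers = [set(seed_nodes)]  # start at the output (L-th) layer
    for _ in range(num_layers):
        # A node's own previous representation is needed as well.
        frontier = set(layers[-1])
        for node in layers[-1]:
            neighbors = in_neighbors.get(node, [])
            if fanout is not None and len(neighbors) > fanout:
                neighbors = rng.sample(neighbors, fanout)
            frontier.update(neighbors)
        layers.append(frontier)
    return layers[::-1]  # index 0 = input layer, index -1 = output layer


# Directed 4-cycle 10 -> 11 -> 12 -> 13 -> 10, expressed as incoming neighbours.
in_neighbors = {11: [10], 12: [11], 13: [12], 10: [13]}
layers = build_dependency_layers([12], in_neighbors, num_layers=2)
# layers == [{10, 11, 12}, {11, 12}, {12}]
```

Only the nodes in `layers[0]` need their input features loaded for this mini-batch, which is what makes mini-batch training feasible on large graphs.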

For a homogeneous graph, target_relation is inferred automatically.

The node_features_property parameter must also be provided by the user to specify where the original node features are saved. The graph neural network needs those features to compute node embeddings.

All other parameters are optional.

When setting the model parameters, if you set split_ratio to 1.0, the model will train on the whole dataset without validating its performance on a validation set. However, if split_ratio is a value between 0.0 and 1.0 but the graph is too small to be split that way, an exception will be thrown.
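
The behaviour of split_ratio can be pictured with a small sketch. The function `split_edges` is hypothetical, and the rule used here to decide that a graph is "too small" (one of the two sets would end up empty) is an assumption; the module's actual check may differ.

```python
import random


def split_edges(edges, split_ratio=0.8, seed=0):
    """Split edges into (training, validation) lists according to split_ratio."""
    if not 0.0 < split_ratio <= 1.0:
        raise ValueError("split_ratio must be in (0.0, 1.0]")

    shuffled = list(edges)
    random.Random(seed).shuffle(shuffled)

    if split_ratio == 1.0:
        # Train on everything, no validation set.
        return shuffled, []

    num_train = int(len(shuffled) * split_ratio)
    if num_train == 0 or num_train == len(shuffled):
        # Assumed rule: both sets must be non-empty.
        raise ValueError("graph is too small for the requested split_ratio")
    return shuffled[:num_train], shuffled[num_train:]


edges = [(10, 11), (11, 12), (12, 13), (13, 10)]
train_edges, val_edges = split_edges(edges, split_ratio=0.8)
```

With four edges and a ratio of 0.8, three edges go to training and one to validation; with a single edge the same ratio raises an error.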

If add_self_loop is set to true, a self-loop edge is added to every node to improve link prediction performance. Self-loop edges are added only with the edge type self, not in any other way.

Here is the description of all parameters supported by link prediction that you can set by calling the set_model_parameters method:


| Name | Type | Default | Description |
| hidden_features_size | mgp.List[int] | [16, 16] | Defines the size of each hidden layer in the architecture. Input feature size is determined automatically while converting the original graph to the DGL compatible one. |
| layer_type | str | graph_attn | Supported values are graph_sage and graph_attn. |
| num_epochs | int | 100 | The number of epochs for model training. |
| optimizer | str | ADAM | Supported values are ADAM and SGD. |
| learning_rate | float | 0.01 | Optimizer's learning rate. |
| split_ratio | float | 0.8 | The split ratio between the training and the validation set. There is no test dataset because it's assumed that the user first needs to create new edges in the original dataset to test a model on them. |
| node_features_property | str | features | Property name where the node features are saved. |
| device_type | str | cpu | Defines if the model will be trained using the CPU or Cuda GPU. To run on Cuda GPU, check if the system supports it with torch.cuda.is_available(), then set this flag to cuda. |
| console_log_freq | int | 5 | Specifies how often results will be printed. This also directly specifies which results will be returned as training and validation results when calling the training method. |
| checkpoint_freq | int | 5 | Select the number of epochs on which the model will be saved. The model is persisted on disc. |
| aggregator | str | mean | Aggregator used in GraphSAGE model. Supported values are lstm, pool, mean and gcn. |
| metrics | mgp.List[str] | [loss, accuracy, auc_score, precision, recall, f1, true_positives, true_negatives, false_positives, false_negatives] | Metrics used to evaluate the training model on the validation set. Additionally, epoch information will always be displayed. |
| predictor_type | str | dot | Type of the predictor. A predictor is used for combining node scores to edge scores. Supported values are dot and mlp. |
| attn_num_heads | List[int] | [4, 1] | GAT can support the usage of more than one head in each layer except the last one. The size of the list must be the same as the number of layers specified by the hidden_features_size parameter. |
| tr_acc_patience | int | 8 | Training patience specifies for how many epochs drop in accuracy on the validation set is tolerated before the training is stopped. |
| context_save_dir | str | None | Path where the model and predictor will be saved every checkpoint_freq epochs. |
| target_relation | str | None | Unique edge type used for training. Users can provide only edge_type or tuple of the source node, edge type, dest_node if the same edge_type is used with more source-destination node combinations. |
| num_neg_per_pos_edge | int | 1 | Number of negative edges that will be sampled per one positive edge in the mini-batch training. |
| batch_size | int | 256 | Batch size used in both training and validation procedure. It specifies the number of indexes in each batch. |
| sampling_workers | int | 5 | Number of workers that will cooperate in the sampling procedure in the training and validation. |
| last_activation_function | str | sigmoid | Activation function that is applied after the last layer in the model and before the predictor_type. Currently, only sigmoid is supported. |
| add_reverse_edges | bool | False | Whether the module should add reverse edges for each existing edge in the obtained graph. If the source and destination node are of the same type, edges of the same edge type will be created. If the source and destination nodes are different, then the prefix rev_ will be added to the previous edge type. Reverse edges will be excluded as message passing edges for corresponding supervision edges. |
| add_self_loop | bool | False | Flag which specifies whether self loop edges will be added to every node in the graph with edge_type set to "self". |


Output:

  • status: bool -> True if all parameters were successfully updated, False otherwise.
  • message: str -> OK if all parameters were successfully updated, Error message otherwise.


When calling the procedure, send only those parameters that need changing from their default values:

CALL link_prediction.set_model_parameters({num_epochs: 100, node_features_property: "features", tr_acc_patience: 8, target_relation: "CITES", batch_size: 256, last_activation_function: "sigmoid", add_reverse_edges: True})
YIELD status, message
RETURN status, message;
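
The tr_acc_patience parameter describes a form of early stopping. A minimal sketch of how such a patience counter might work is shown below. The function name `stopping_epoch` and the exact counting rule (stop once validation accuracy has failed to improve for more than `patience` consecutive evaluations) are assumptions, not the module's confirmed behaviour.

```python
def stopping_epoch(val_accuracies, patience=8):
    """Return the 1-based epoch at which training would stop, or None.

    Training stops once accuracy has not improved on the best value seen
    so far for more than `patience` consecutive epochs.
    """
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, accuracy in enumerate(val_accuracies, start=1):
        if accuracy > best:
            best = accuracy
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement > patience:
                return epoch
    return None


history = [0.5, 0.6, 0.55, 0.5, 0.4]
stopped_at = stopping_epoch(history, patience=2)  # stops at epoch 5
never_stopped = stopping_epoch(history, patience=8)  # None: runs to the end
```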


The train procedure is used to train a model and it doesn't take any input parameters.


Output:

  • training_results: List[Dict[str, float]] -> List of training results through epochs. Model's performance is evaluated every console_log_freq epochs.
  • validation_results: List[Dict[str, float]] -> List of validation results through epochs. Model's performance is evaluated every console_log_freq epochs.


To train the module and get the training and validation results summarized through epochs, use the following query:

CALL link_prediction.train()
YIELD training_results, validation_results
RETURN training_results, validation_results;


The get_training_results() procedure is used to get performance data obtained from the last training. The results are in the same format as the results of the train() procedure. If there is no loaded model, an exception will be thrown.


Output:

  • training_results: List[Dict[str, float]] -> List of training results through epochs. Model's performance is evaluated every console_log_freq epochs.
  • validation_results: List[Dict[str, float]] -> List of validation results through epochs. Model's performance is evaluated every console_log_freq epochs.


To get the performance data from the last training, use the following query:

CALL link_prediction.get_training_results()
YIELD training_results, validation_results
RETURN training_results, validation_results;


The predict() procedure takes two arguments, src_vertex and dest_vertex, and predicts whether there is an edge between them. It supports both an “actual” prediction scenario, in which the edge doesn’t exist and the user wants to predict whether there is one, and a scenario in which the edge already exists between the two vertices and the user wants to check the model’s evaluation.


Input:

  • src_vertex: mgp.Vertex -> Source vertex of the edge.
  • dest_vertex: mgp.Vertex -> Destination vertex of the edge.

Output:

  • score: mgp.Number -> Score between 0 and 1 that represents the probability of two nodes being connected.


To get a prediction about the existence of a relationship, use the following query:

MATCH (v1:PAPER {id: "ID_1"})
MATCH (v2:PAPER {id: "ID_2"})
CALL link_prediction.predict(v1, v2)
YIELD score
RETURN score;
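
To give a feel for how a score between 0 and 1 can arise from two node embeddings, here is a sketch of a dot-product predictor followed by a sigmoid. This is only one plausible reading of the dot predictor type; the function name `dot_predictor_score` and the embeddings are hypothetical, and the module's internals may combine the steps differently.

```python
import math


def dot_predictor_score(src_embedding, dest_embedding):
    """Map two node embeddings to an edge score in [0, 1]."""
    if len(src_embedding) != len(dest_embedding):
        raise ValueError("embeddings must have the same length")

    # Dot product of the two embeddings gives an unbounded raw score.
    logit = sum(a * b for a, b in zip(src_embedding, dest_embedding))

    # Numerically stable sigmoid squashes the raw score into [0, 1].
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    exp_logit = math.exp(logit)
    return exp_logit / (1.0 + exp_logit)


# Orthogonal embeddings give a raw score of 0, which maps to exactly 0.5.
neutral_score = dot_predictor_score([1.0, 0.0], [0.0, 1.0])
```

An mlp predictor would replace the dot product with a small learned network over the two embeddings, at the cost of extra parameters to train.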


The recommend() procedure recommends the best k nodes from dest_vertices for src_vertex. It is implemented efficiently using the max heap data structure. The best nodes are determined based on the edge scores. Metrics specific to recommendation systems (precision@k, recall@k, f1@k and average precision) are logged to the standard output. The effective k is equal to min(k, length(dest_vertices), length(results)), where results is the list of all recommendations given by the model (that is, the candidates classified as positive examples).
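
The selection of the top k candidates and the recommendation metrics can be sketched as follows. This sketch uses `heapq.nlargest` from the Python standard library rather than an explicit max heap, and the 0.5 threshold for a positive classification, as well as the function names, are assumptions made for illustration.

```python
import heapq


def recommend_top_k(scored_candidates, k, threshold=0.5):
    """Pick the best k (score, node) pairs among positively classified ones."""
    positives = [(score, node) for score, node in scored_candidates if score > threshold]
    effective_k = min(k, len(scored_candidates), len(positives))
    return heapq.nlargest(effective_k, positives, key=lambda pair: pair[0])


def precision_recall_f1_at_k(recommended_nodes, relevant_nodes):
    """Compute precision@k, recall@k and f1@k for one list of recommendations."""
    k = len(recommended_nodes)
    hits = sum(1 for node in recommended_nodes if node in relevant_nodes)
    precision = hits / k if k else 0.0
    recall = hits / len(relevant_nodes) if relevant_nodes else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return precision, recall, f1


candidates = [(0.9, "a"), (0.2, "b"), (0.7, "c"), (0.6, "d")]
top = recommend_top_k(candidates, k=5)
# top == [(0.9, "a"), (0.7, "c"), (0.6, "d")]
metrics = precision_recall_f1_at_k([node for _, node in top], {"a", "d", "e"})
```

Here k = 5 is reduced to 3 because only three candidates are classified as positive, which mirrors the min(...) rule described above.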


Input:

  • src_vertex: mgp.Vertex → Source node.
  • dest_vertices: List[mgp.Vertex] → Destination nodes. If they are not of the same type, an exception is thrown.
  • k: int → Number of edges to recommend.

Output:

  • score: mgp.Number → Score between 0 and 1 that represents the probability of two nodes being connected.
  • recommendation: mgp.Vertex → A reference to the target node.


To get the top k recommended nodes, use the following query:

MATCH (v1:Customer {id: "8779-QRDMV"})
MATCH (p:Plan)
WITH collect(p) AS all_plans, v1
CALL link_prediction.recommend(v1, all_plans, 5)
YIELD score, recommendation
RETURN v1, score, recommendation;


The load_context() procedure loads the context, that is, the model and the predictor. If the user specifies the path, the procedure will try to load the context from there. Otherwise, the context will be loaded using the default parameter specified in the link_prediction_parameters module.


Input:

  • path: str → Path to the folder where the model and the predictor are saved.

Output:

  • status: mgp.Any → True to indicate that execution went well.


To load the context, use the following query:

CALL link_prediction.load_context()
YIELD status 
RETURN status;


You can explicitly reset the parameters at any time by calling the reset_parameters() procedure. Note, however, that for implementation reasons the parameters will be reset before training even if you don't reset them explicitly.


Output:

  • status: mgp.Any → True to indicate that the procedure finished successfully.


To reset the parameters, use the following query:

CALL link_prediction.reset_parameters()
YIELD status
RETURN status;


Memgraph extensively tested the model on the CORA dataset and the Telecom recommendation dataset. To show you how the training performance could progress through epochs, here are the results for one of our basic models tried on the Cora dataset:



Database state

The database contains the following data, created with these Cypher queries:

CREATE (v1:PAPER {id: 10, features: [1, 2, 3]});
CREATE (v2:PAPER {id: 11, features: [1.54, 0.3, 1.78]});
CREATE (v3:PAPER {id: 12, features: [0.5, 1, 4.5]});
CREATE (v4:PAPER {id: 13, features: [0.78, 0.234, 1.2]});
MATCH (v1:PAPER {id: 10}), (v2:PAPER {id: 11}) CREATE (v1)-[e:CITES {}]->(v2);
MATCH (v2:PAPER {id: 11}), (v3:PAPER {id: 12}) CREATE (v2)-[e:CITES {}]->(v3);
MATCH (v3:PAPER {id: 12}), (v4:PAPER {id: 13}) CREATE (v3)-[e:CITES {}]->(v4);
MATCH (v4:PAPER {id: 13}), (v1:PAPER {id: 10}) CREATE (v4)-[e:CITES {}]->(v1);

Set model parameters

To set the model parameters, use the following query:

CALL link_prediction.set_model_parameters({target_relation: ["PAPER", "CITES", "PAPER"], node_features_property: "features",
split_ratio: 1.0, predictor_type: "mlp", num_epochs: 100, hidden_features_size: [256], attn_num_heads: [1]}) 
YIELD status, message
RETURN status, message;

Train the module

To train the module, use the following query:

CALL link_prediction.train() YIELD training_results, validation_results
RETURN training_results, validation_results;

Train results:

| epoch_num          | accuracy           | auc_score          | loss               | precision          |
| 17                 | 0.833              | 0.906              | 0.428              | 1.0                |
| 18                 | 0.917              | 0.938              | 0.393              | 1.0                |
| 19                 | 0.833              | 0.938              | 0.365              | 0.75               |
| 20                 | 0.917              | 0.938              | 0.341              | 1.0                |
| 21                 | 0.917              | 0.938              | 0.315              | 1.0                |
| 22                 | 0.833              | 0.969              | 0.296              | 0.75               |
| 23                 | 0.917              | 1.0                | 0.277              | 1.0                |
| 24                 | 0.917              | 1.0                | 0.246              | 0.8                |
| 25                 | 0.917              | 1.0                | 0.233              | 0.8                |
| 26                 | 1.0                | 1.0                | 0.202              | 1.0                |

Predict relationships

To predict relationships, use the following query:

MATCH (v1:PAPER {id: 10})
MATCH (v2:PAPER {id: 12})
CALL link_prediction.predict(v1, v2)
YIELD score
RETURN score;


| score |
| 0.104 |


Why can I get into problems with reverse edges?

Having reverse edges in your dataset can be a problem if they are not excluded from the message passing edges when predicting their opposite (supervision) edges. The best approach is to start with a directed graph and set add_reverse_edges in the set_model_parameters procedure. The module will then add the reverse edges automatically, in a way that doesn't leak information about the supervision edges.
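
The naming rule for reverse edges and their exclusion from message passing can be sketched as follows. The tuple layout, the function names and the SUBSCRIBES edge type are hypothetical and chosen only for this example; the module performs these steps internally through DGL.

```python
def reverse_of(edge):
    """Build the reverse edge; add the rev_ prefix only if node types differ."""
    src, src_type, edge_type, dst, dst_type = edge
    rev_type = edge_type if src_type == dst_type else "rev_" + edge_type
    return (dst, dst_type, rev_type, src, src_type)


def add_reverse_edges(edges):
    """Return the original edges followed by one reverse edge for each."""
    return list(edges) + [reverse_of(edge) for edge in edges]


def message_passing_edges(all_edges, supervision_edge):
    """Drop the supervision edge and its reverse so neither leaks the label."""
    excluded = {supervision_edge, reverse_of(supervision_edge)}
    return [edge for edge in all_edges if edge not in excluded]


edges = [
    (1, "Customer", "SUBSCRIBES", 2, "Plan"),  # hypothetical edge type
    (10, "PAPER", "CITES", 11, "PAPER"),
]
bidirected = add_reverse_edges(edges)
visible = message_passing_edges(bidirected, edges[1])
```

If the reverse CITES edge stayed visible while the model predicted the original one, the answer would effectively be handed to the model, which is why both are removed together.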

What is a transductive dataset split?

The transductive dataset split assumes that the entire graph can be observed in all dataset splits. We distinguish four types of edges: validation, training, message passing and supervision edges.

The transductive dataset split is described in detail by Prof. Jure Leskovec in one of his presentations for the Graph ML course.