Node classification is the problem of predicting the correct label for a node based on its neighbors’ labels and structural similarities.

About the query module

This query module contains all the functions you need to train a graph neural network (GNN) model on Memgraph.

The node_classification module supports the following:

  • homogeneous and heterogeneous graphs
  • multiple-label and multi-edge-type graphs
  • datasets of any size
  • the following model architectures:
    • Graph Attention with Jumping Knowledge
    • multiple versions of Graph Attention Networks (GAT)
    • GraphSAGE
  • early stopping
  • calculation of various metrics
  • predictions for specified nodes
  • model saving and loading
  • recommendation system use cases
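The list above mentions early stopping but does not specify the criterion the module uses. As an illustration only, here is a common patience-based rule that stops training once the validation loss has not improved for a given number of epochs. The EarlyStopping class and its patience and min_delta arguments are hypothetical and are not parameters of node_classification.

```python
class EarlyStopping:
    """Patience-based early stopping on validation loss (illustrative sketch)."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0) -> None:
        self.patience = patience      # hypothetical: epochs to wait without improvement
        self.min_delta = min_delta    # hypothetical: smallest decrease that counts as improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
stopped_at = None
for epoch, loss in enumerate([0.9, 0.7, 0.72, 0.71, 0.5], start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at)  # 4
```

Epoch 5 is never reached. After epochs 3 and 4 fail to beat 0.7, patience runs out, even though a later epoch would have improved. This is the usual trade-off when choosing patience.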

The easiest way to test node_classification is to download Memgraph Platform and use some of the preloaded datasets in Memgraph Lab. If you want to explore our implementation, go to github/memgraph/mage and find python/. Feel free to give us a ⭐ if you like the code.

Feel free to open a GitHub issue or start a discussion on Discord if you want to speed up development.


Load a dataset into Memgraph, call set_model_parameters(), and start training your model. When training is done, the query module saves the models. Afterwards, you can test the model on other data (for example, data the model has not seen yet) and inspect the results. The module reports the mean average precision for every training and evaluation epoch.

To summarize, the basic node classification workflow is as follows:

  • load data into Memgraph
  • set parameters by calling the set_model_parameters() procedure. Make sure the node_features property is set on the nodes.
  • call the train() procedure
  • inspect the training results (optional) by calling the get_training_data() procedure
  • optionally, use save_model() and load_model()
  • predict a node's class by calling the predict() procedure
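If you drive Memgraph from a Python script, the workflow above reduces to an ordered list of Cypher queries. The sketch below only builds those query strings. to_cypher_map and workflow_queries are hypothetical helpers, and sending the queries requires a Memgraph client of your choice. The predict query assumes that nodes carry an id property, as in the predict() example.

```python
import json


def to_cypher_map(params: dict) -> str:
    """Render a flat Python dict as a Cypher map literal, e.g. {num_epochs: 10}."""
    # Keys stay unquoted (plain identifiers); json.dumps renders strings, numbers,
    # booleans and lists in a form Cypher also accepts.
    items = ", ".join(f"{key}: {json.dumps(value)}" for key, value in params.items())
    return "{" + items + "}"


def workflow_queries(params: dict, node_id: int) -> list[str]:
    """Return the Cypher queries of the basic workflow, in execution order."""
    return [
        f"CALL node_classification.set_model_parameters({to_cypher_map(params)}) YIELD * RETURN *;",
        "CALL node_classification.train() YIELD * RETURN *;",
        "CALL node_classification.get_training_data() YIELD * RETURN *;",
        "CALL node_classification.save_model() YIELD * RETURN *;",
        f"MATCH (n {{id: {node_id}}}) CALL node_classification.predict(n) "
        "YIELD predicted_class RETURN predicted_class;",
    ]


for query in workflow_queries({"layer_type": "GATJK", "num_epochs": 10}, node_id=1):
    print(query)
```

This helper handles flat parameter maps only. A nested dict would come out with double-quoted keys, which Cypher map literals do not accept.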

This MAGE module is still in an early stage. We recommend using it only for exploring or learning about node classification. If you want it to be production-ready, open a GitHub issue or drop us a comment on Discord.



set_model_parameters(params)

This procedure initializes all global variables, which you can change via the params dictionary. It first checks whether the variables in params are defined correctly. If they are, the map of default global parameters is overridden with the user-defined params. After that, the procedure executes the previously defined functions declare_globals and declare_model_and_data, which set each global variable.


Input values

  • params (mgp.Map, optional) ➑ User-defined parameters from the query module. Defaults to {}. Supported keys:
    • hidden_features_size (List[Int], default: [16, 16]) ➑ Embedding dimension for each node in a new layer.
    • layer_type (String, default: "GATJK") ➑ Type of layer used. Supported types: GATJK, GAT, GRAPHSAGE.
    • aggregator (String, default: "mean") ➑ Type of aggregator used. Supported type: mean.
    • learning_rate (Float, default: 0.1) ➑ Optimizer's learning rate.
    • weight_decay (Float, default: 5e-4) ➑ Optimizer's weight decay.
    • split_ratio (Float, default: 0.8) ➑ Ratio between training and validation data.
    • metrics (List[String], default: ["loss", "accuracy", "f1_score", "precision", "recall", "num_wrong_examples"]) ➑ List of metrics to report. Supports any combination of "loss", "accuracy", "f1_score", "precision", "recall", and "num_wrong_examples".
    • node_id_property (String, default: "id") ➑ Name of the node property that holds the node ID.
    • num_epochs (Integer, default: 100) ➑ Number of epochs for model training.
    • console_log_freq (Integer, default: 5) ➑ How often results are printed.
    • checkpoint_freq (Integer, default: 5) ➑ How often the model is saved. The model is persisted on disk.
    • device_type (String, default: "cpu") ➑ Whether the model is trained on the CPU (cpu) or a CUDA GPU (cuda). To train on a CUDA GPU, check that the system supports it with torch.cuda.is_available(), then set this flag to cuda.
    • path_to_model (String, default: "/tmp/torch_models") ➑ Path for loading and storing the model.


  • Exception ➑ Raised if a variable in the params dictionary is not defined correctly.


Return values

  • mgp.Record(hidden_features_size=list, layer_type=str, aggregator=str, learning_rate=float, weight_decay=float, split_ratio=float, metrics=mgp.Any, node_id_property=str, num_epochs=int, console_log_freq=int, checkpoint_freq=int, device_type=str, path_to_model=str) ➑ Map of the parameters set for training.


  CALL node_classification.set_model_parameters(
    {layer_type: "GATJK", learning_rate: 0.001, hidden_features_size: [16,16], class_name: "fraud", features_name: "embedding"}
  ) YIELD * RETURN *;
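The validate-then-override behavior described for set_model_parameters() can be sketched in plain Python. This is a minimal illustration, not the module's source. merge_params, DEFAULT_PARAMS, and EXPECTED_TYPES are hypothetical names, and only a subset of the documented parameters is checked. Keys outside that subset pass through unchecked.

```python
import copy

# Hypothetical subset of the documented defaults; the real module defines more keys.
DEFAULT_PARAMS = {
    "hidden_features_size": [16, 16],
    "layer_type": "GATJK",
    "learning_rate": 0.1,
    "split_ratio": 0.8,
    "num_epochs": 100,
}

# Accepted Python types per key (an int is also accepted for learning_rate).
EXPECTED_TYPES = {
    "hidden_features_size": (list,),
    "layer_type": (str,),
    "learning_rate": (int, float),
    "split_ratio": (float,),
    "num_epochs": (int,),
}

SUPPORTED_LAYER_TYPES = {"GATJK", "GAT", "GRAPHSAGE"}


def merge_params(params: dict) -> dict:
    """Check the user's params, then override a fresh copy of the defaults with them."""
    for key, value in params.items():
        expected = EXPECTED_TYPES.get(key)
        # Keys outside this sketch's subset are passed through unchecked.
        if expected is not None and not isinstance(value, expected):
            raise Exception(f"Parameter {key} has invalid type {type(value).__name__}")
    if params.get("layer_type", DEFAULT_PARAMS["layer_type"]) not in SUPPORTED_LAYER_TYPES:
        raise Exception(f"Unsupported layer_type: {params['layer_type']}")
    merged = copy.deepcopy(DEFAULT_PARAMS)  # never mutate the shared defaults
    merged.update(params)
    return merged


print(merge_params({"layer_type": "GAT", "learning_rate": 0.001})["num_epochs"])  # 100
```

Validation runs before anything is merged, so an invalid map raises without leaving the settings half-updated.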


train(num_epochs)

This procedure trains the model. First, it declares the data, model, optimizer, and criterion; then it runs the training.
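The split_ratio parameter controls how the declared data is divided into training and validation sets. The module may split differently (for example, only over labeled nodes). Assuming a simple random split over node IDs, the idea looks like this. split_nodes and its seed argument are hypothetical.

```python
import random


def split_nodes(node_ids: list, split_ratio: float = 0.8, seed: int = 0) -> tuple[list, list]:
    """Shuffle node IDs and cut them into training and validation parts."""
    if not 0.0 < split_ratio < 1.0:
        raise ValueError("split_ratio must be strictly between 0 and 1")
    shuffled = list(node_ids)              # copy so the caller's list is untouched
    random.Random(seed).shuffle(shuffled)  # seeded RNG for a reproducible split
    cut = int(len(shuffled) * split_ratio)
    return shuffled[:cut], shuffled[cut:]


train_ids, val_ids = split_nodes(list(range(10)), split_ratio=0.8)
print(len(train_ids), len(val_ids))  # 8 2
```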


Input values

  • num_epochs (int, optional) ➑ Number of epochs (default: 100).


  • Exception ➑ Raised if the graph is empty.

Return values

  • epoch: int ➑ Epoch number.
  • loss: float ➑ Loss of the model on training data.
  • val_loss: float ➑ Loss of the model on validation data.
  • train_log: list ➑ List of metrics on training data.
  • val_log: list ➑ List of metrics on validation data.


  CALL node_classification.train() YIELD * RETURN *;
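The per-epoch records returned by train(), together with console_log_freq and checkpoint_freq, suggest a loop like the one below. It is a framework-free sketch, not the module's implementation. training_loop is hypothetical, the model-specific work is injected as the train_step, eval_step, and save_checkpoint callables, and the sketch assumes both frequencies count epochs.

```python
from typing import Callable


def training_loop(
    train_step: Callable[[int], float],
    eval_step: Callable[[int], float],
    save_checkpoint: Callable[[int], None],
    num_epochs: int = 100,
    console_log_freq: int = 5,
    checkpoint_freq: int = 5,
) -> list[dict]:
    """Run epochs, print every console_log_freq epochs, checkpoint every checkpoint_freq."""
    records = []
    for epoch in range(1, num_epochs + 1):
        loss = train_step(epoch)      # one optimization pass over the training nodes
        val_loss = eval_step(epoch)   # loss on the validation nodes, no parameter updates
        records.append({"epoch": epoch, "loss": loss, "val_loss": val_loss})
        if epoch % console_log_freq == 0:
            print(f"epoch {epoch}: loss={loss:.4f} val_loss={val_loss:.4f}")
        if epoch % checkpoint_freq == 0:
            save_checkpoint(epoch)
    return records


saved = []
history = training_loop(
    train_step=lambda e: 1.0 / e,
    eval_step=lambda e: 1.2 / e,
    save_checkpoint=saved.append,
    num_epochs=10,
)
print(saved)  # [5, 10]
```

Injecting the steps as callables keeps the loop testable without PyTorch. In a real setup, train_step would run a forward and backward pass, and save_checkpoint would write the model under path_to_model.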


get_training_data()

Use the following procedure to get the data logged during training.

Return values

  • epoch: int ➑ Epoch number for the current record's logged data.
  • loss: float ➑ Loss in the epoch.
  • train_log: mgp.Any ➑ Training metrics for the epoch.
  • val_log: mgp.Any ➑ Validation metrics for the epoch.


  CALL node_classification.get_training_data() YIELD * RETURN *;
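The train_log and val_log values contain metrics such as accuracy, precision, recall, f1_score, and num_wrong_examples. To show what those numbers mean, here is a pure-Python computation from true and predicted class labels. Macro averaging over classes is an assumption of this sketch; the module may average differently. classification_metrics is a hypothetical name.

```python
def classification_metrics(y_true: list[int], y_pred: list[int]) -> dict:
    """Accuracy, macro-averaged precision/recall/F1, and the number of wrong predictions."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    classes = sorted(set(y_true) | set(y_pred))
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    wrong = sum(1 for t, p in zip(y_true, y_pred) if t != p)
    return {
        "accuracy": 1 - wrong / len(y_true),
        "precision": sum(precisions) / len(classes),  # macro average: each class weighs equally
        "recall": sum(recalls) / len(classes),
        "f1_score": sum(f1s) / len(classes),
        "num_wrong_examples": wrong,
    }


print(classification_metrics([0, 0, 1, 1], [0, 1, 1, 1]))
```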


save_model()

This procedure saves the model to the specified folder. If the folder already contains max_models_to_keep models, the oldest model is deleted.


  • Exception ➑ Raised if the model is not initialized or defined.

Return values

  • path (str) ➑ Path to the stored model.
  • status (str) ➑ Status of the stored model.


  CALL node_classification.save_model() YIELD * RETURN *;
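The rotation rule (keep at most max_models_to_keep models and delete the oldest) can be sketched with the standard library. Both the save_with_rotation function and the model_NNNNNN.pt naming scheme with a zero-padded, increasing index are hypothetical. The zero padding makes lexicographic order match save order.

```python
import os
import tempfile

PREFIX, SUFFIX = "model_", ".pt"  # hypothetical naming scheme


def save_with_rotation(folder: str, payload: bytes, max_models_to_keep: int = 3) -> str:
    """Save payload as the newest model file, then delete the oldest beyond the limit."""
    if max_models_to_keep < 1:
        raise ValueError("max_models_to_keep must be at least 1")
    os.makedirs(folder, exist_ok=True)
    existing = sorted(
        name for name in os.listdir(folder) if name.startswith(PREFIX) and name.endswith(SUFFIX)
    )
    # The next index is one past the newest saved model (or 0 for an empty folder).
    next_index = int(existing[-1][len(PREFIX):-len(SUFFIX)]) + 1 if existing else 0
    path = os.path.join(folder, f"{PREFIX}{next_index:06d}{SUFFIX}")
    with open(path, "wb") as f:
        f.write(payload)
    existing.append(os.path.basename(path))
    for old in existing[:-max_models_to_keep]:  # everything except the newest N
        os.remove(os.path.join(folder, old))
    return path


with tempfile.TemporaryDirectory() as folder:
    for _ in range(5):
        save_with_rotation(folder, b"weights", max_models_to_keep=3)
    kept = sorted(os.listdir(folder))
print(kept)  # ['model_000002.pt', 'model_000003.pt', 'model_000004.pt']
```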


load_model(num)

This procedure loads the model from the specified folder.


Input values

  • num (int, optional) ➑ Ordinal number of the model to load from the default path on disk (default: 0, i.e., the newest model).

Return values

  • path: str ➑ Path of the loaded model.


  CALL node_classification.load_model() YIELD * RETURN *;
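Selecting a model by ordinal number, where num=0 means the newest, comes down to sorting the saved files from newest to oldest. This sketch assumes a hypothetical naming scheme with a zero-padded, increasing index (model_000000.pt, model_000001.pt, and so on). model_path_by_ordinal is not part of the module.

```python
import os
import tempfile


def model_path_by_ordinal(folder: str, num: int = 0) -> str:
    """Return the path of the num-th newest model file (num=0 is the newest)."""
    # Zero-padded names sort oldest-to-newest, so reverse the order for newest first.
    models = sorted((name for name in os.listdir(folder) if name.endswith(".pt")), reverse=True)
    if not 0 <= num < len(models):
        raise IndexError(f"No model with ordinal {num}; {len(models)} model(s) available")
    return os.path.join(folder, models[num])


with tempfile.TemporaryDirectory() as folder:
    for index in range(3):
        open(os.path.join(folder, f"model_{index:06d}.pt"), "wb").close()
    newest = os.path.basename(model_path_by_ordinal(folder))
    second = os.path.basename(model_path_by_ordinal(folder, num=1))
print(newest, second)  # model_000002.pt model_000001.pt
```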


predict(vertex)

This procedure predicts the class of a single node. We suggest loading the test data (data without labels) as well; test data is not part of the training or validation process.


Input values

  • vertex: mgp.Vertex ➑ Node whose class is predicted.

Return values

  • predicted_class: int ➑ Predicted class for the specified node.


  MATCH (n {id: 1}) CALL node_classification.predict(n) YIELD predicted_class RETURN predicted_class;
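For a single node, a GNN classifier typically outputs one score per class, and the predicted class is the index of the highest score. Assuming predict() follows this common argmax convention, the final step reduces to the following. predict_class is a hypothetical helper.

```python
def predict_class(scores: list[float]) -> int:
    """Return the index of the highest class score; ties go to the lowest index."""
    if not scores:
        raise ValueError("scores must not be empty")
    # max() returns the first maximal element, which makes tie-breaking deterministic.
    return max(range(len(scores)), key=scores.__getitem__)


print(predict_class([0.1, 2.3, -0.4]))  # 1
```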


reset()

This procedure resets all variables to their default values.

Return values

  • status (str) ➑ Status of the reset procedure.


  CALL node_classification.reset() YIELD * RETURN *;
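Resetting to defaults is safest when the defaults are copied rather than shared. This is a hypothetical sketch: DEFAULTS, state, and the returned status text are illustrative, not the module's. A deep copy ensures that a list such as hidden_features_size, mutated during one session, cannot leak into the defaults used by the next.

```python
import copy

# Hypothetical defaults; the documented ones are listed under set_model_parameters().
DEFAULTS = {"layer_type": "GATJK", "num_epochs": 100, "hidden_features_size": [16, 16]}
state = copy.deepcopy(DEFAULTS)


def reset() -> str:
    """Restore every setting in state to its default value and report a status."""
    global state
    # A deep copy means later in-place edits to state never touch DEFAULTS.
    state = copy.deepcopy(DEFAULTS)
    return "reset done"  # hypothetical status text


state["hidden_features_size"].append(32)  # simulate a mutation during a session
state["num_epochs"] = 5
status = reset()
```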