temporal_graph_networks

Temporal graph networks (TGNs) are a type of graph neural network (GNN) for dynamic graphs. In recent years, GNNs have become very popular because they can perform a wide variety of machine learning tasks on graphs, such as link prediction and node classification. This rise started with graph convolutional networks (GCN), introduced by Kipf et al., followed by GraphSAGE, introduced by Hamilton et al. More recently, Veličković et al. presented graph attention networks (GAT), which bring the attention mechanism to graphs.

The last two methods are well suited to inductive learning, but they were not specifically developed to handle the different events that occur on graphs, such as node feature updates, node deletions and edge deletions. These events happen regularly in real-world networks such as Twitter, where users update their profiles, delete their profiles or simply unfollow other users.

In their work, Rossi et al. introduce Temporal graph networks, an architecture for machine learning on streamed graphs, a rapidly-growing ML use case.

The module offers the following features, as introduced by Rossi et al.:

  • Link prediction (train your TGN to predict new links/edges) and node classification (predict node labels from the graph structure and node/edge features),
  • Graph attention layer and graph sum layer for embedding calculation,
  • Mean and last as message aggregators,
  • mlp and identity (concatenation) as message functions,
  • gru and rnn as memory updaters,
  • Uniform temporal neighborhood sampler,
  • Memory store and raw message store.

This means you can use TGN to predict edges or perform node classification tasks, with graph attention layer or graph sum layer, by using either mean or last as message aggregator, mlp or identity as message function, and finally gru or rnn as memory updater.

In total, this gives you 2✕2✕2✕2✕2 options, that is, 32 options to explore on your graph!
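The count follows from five independent two-way choices. The sketch below simply enumerates them; the option strings are the values accepted by the set_params() parameters described later on this page, and the helper name all_configurations is only illustrative.

```python
from itertools import product

# Two choices for each of the five configurable components.
OPTIONS = {
    "learning_type": ["self_supervised", "supervised"],
    "layer_type": ["graph_attn", "graph_sum"],
    "message_aggregator_type": ["mean", "last"],
    "edge_message_function_type": ["mlp", "identity"],
    "memory_updater_type": ["gru", "rnn"],
}


def all_configurations(options=OPTIONS):
    """Return every combination of option values as a list of dicts."""
    names = list(options)
    return [dict(zip(names, values)) for values in product(*options.values())]


configurations = all_configurations()
print(len(configurations))  # 2 * 2 * 2 * 2 * 2 = 32
```

Each dictionary in the list can be passed as (part of) the params map, which makes it easy to loop over configurations when comparing them on your own graph.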

What is not implemented in the module:

  • Node update/deletion events since they occur very rarely - although we have prepared a codebase to easily integrate them.
  • Edge deletion events.
  • Time projection and identity embedding calculation, since the authors report that they perform very poorly on all datasets (although it is trivial to add a new layer).

The query module is implemented using PyTorch. Edge features and node features are extracted from the input (mgp.Edge and mgp.Vertex objects). With a trigger set, the update() procedure parses all new edges and extracts the information the TGN needs for batch-by-batch processing.

The following piece of code shows what is extracted from edges while the batch is filling up. When the current batch size reaches batch_size (defined with set_params()), the extracted information is forwarded to the TGN, which extends torch.nn.Module.

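import dataclasses
from typing import Dict

import numpy as np
import torch


@dataclasses.dataclass
class QueryModuleTGNBatch:
    current_batch_size: int  # number of edges collected so far
    sources: np.ndarray  # source node ids, one per edge
    destinations: np.ndarray  # destination node ids, one per edge
    timestamps: np.ndarray  # one timestamp per edge
    edge_idxs: np.ndarray  # edge ids
    node_features: Dict[int, torch.Tensor]  # node id -> feature tensor
    edge_features: Dict[int, torch.Tensor]  # edge id -> feature tensor
    batch_size: int  # target batch size
    labels: np.ndarray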

Processing one batch

        self._process_previous_batches()
 
        graph_data = self._get_graph_data(
            np.concatenate([sources.copy(), destinations.copy()], dtype=int),
            np.concatenate([timestamps, timestamps]),
        )
 
        embeddings = self.tgn_net(graph_data)
 
        # ... process negative edges in a similar way
 
        self._process_current_batch(
            sources, destinations, node_features, edge_features, edge_idxs, timestamps
        )

The torch.nn.Module is organized as follows:

  • Processing previous batches: as in the research paper, this computes messages for each node with a message function, aggregates the messages for each node with a message aggregator, and finally updates each node’s memory with a memory updater.
  • Creating the computation graph used by the graph attention layer or graph sum layer.
  • Processing the current batch: creating new interaction or node events and updating the raw message store with the new events.
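To make the aggregation step more concrete, here is a minimal pure-Python sketch of the two message aggregators, last and mean. It is not the module's PyTorch code: the Message tuple layout and the function names are assumptions made for illustration, and plain lists stand in for tensors.

```python
from typing import Dict, List, Sequence, Tuple

# A raw message is (timestamp, message_vector); this layout is illustrative only.
Message = Tuple[float, Sequence[float]]


def aggregate_last(messages: List[Message]) -> List[float]:
    """Keep only the most recent message for a node."""
    _, vector = max(messages, key=lambda item: item[0])
    return list(vector)


def aggregate_mean(messages: List[Message]) -> List[float]:
    """Average all messages for a node, element by element."""
    count = len(messages)
    columns = zip(*(vector for _, vector in messages))
    return [sum(column) / count for column in columns]


def aggregate(per_node: Dict[int, List[Message]], kind: str) -> Dict[int, List[float]]:
    """Apply the chosen aggregator to every node that has pending messages."""
    aggregator = {"last": aggregate_last, "mean": aggregate_mean}[kind]
    return {node: aggregator(msgs) for node, msgs in per_node.items() if msgs}


messages = {6: [(1.0, [1.0, 0.0]), (2.0, [3.0, 4.0])]}
print(aggregate(messages, "last"))  # {6: [3.0, 4.0]}
print(aggregate(messages, "mean"))  # {6: [2.0, 2.0]}
```

The aggregated vector is what the memory updater (gru or rnn) would then combine with the node's previous memory.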

The process repeats: as new edges arrive, the batch fills up, the full batch is forwarded to the TGN, and so on.
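The fill-and-forward loop can be sketched as a small buffer that hands over a batch as soon as it is full. This is a simplified illustration under stated assumptions, not the module's code: EdgeBatcher, the Edge tuple layout and the process_batch callback are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple

Edge = Tuple[int, int, float]  # (source id, destination id, timestamp)


@dataclass
class EdgeBatcher:
    batch_size: int
    process_batch: Callable[[List[Edge]], None]  # stands in for the TGN forward pass
    pending: List[Edge] = field(default_factory=list)

    def add_edges(self, edges: List[Edge]) -> int:
        """Buffer edges and forward every full batch; return the number of batches forwarded."""
        processed = 0
        for edge in edges:
            self.pending.append(edge)
            if len(self.pending) == self.batch_size:
                self.process_batch(self.pending)
                self.pending = []  # start collecting the next batch
                processed += 1
        return processed


seen = []
batcher = EdgeBatcher(batch_size=2, process_batch=seen.append)
batcher.add_edges([(1, 6, 0.0), (2, 6, 1.0), (10, 5, 2.0)])
print(len(seen), len(batcher.pending))  # 1 1
```

In this sketch, edges that do not fill a batch simply wait in the buffer until more edges arrive.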

| Trait | Value |
|---|---|
| Module type | module |
| Implementation | Python |
| Graph direction | directed |
| Edge weights | weighted/unweighted |
| Parallelism | sequential |

Usage

The basic TGN workflow is as follows:

  1. Set parameters by calling the set_params() procedure.
  2. Create a trigger on the edge create event that calls the update() procedure.
  3. Start loading your train queries.
  4. When train queries are loaded, switch the TGN mode to eval by calling the set_eval() procedure.
  5. Load the evaluation queries.
  6. Do a few more epochs of training and evaluation to get the best results by calling the train_and_eval() procedure.

Calling the set_eval() procedure switches the temporal graph network to eval mode. Any edges that arrive afterwards are used to evaluate the model, not to train it.
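The order of the steps can be sketched as a small driver function. The execute argument is a hypothetical callable supplied by the caller (for example, a thin wrapper around whatever database client you use); the Cypher statements themselves are the ones shown on this page.

```python
from typing import Callable, Iterable, List


def run_tgn_workflow(
    execute: Callable[[str], object],
    params_query: str,
    train_queries: Iterable[str],
    eval_queries: Iterable[str],
    num_epochs: int,
) -> None:
    """Issue the workflow queries in order through a caller-supplied executor."""
    execute(params_query)  # step 1: set parameters
    execute(  # step 2: trigger that calls update() on edge creation
        "CREATE TRIGGER create_embeddings ON --> CREATE BEFORE COMMIT "
        "EXECUTE CALL tgn.update(createdEdges) RETURN 1;"
    )
    for query in train_queries:  # step 3: load the train queries
        execute(query)
    execute("CALL tgn.set_eval() YIELD message RETURN message;")  # step 4
    for query in eval_queries:  # step 5: load the evaluation queries
        execute(query)
    execute(  # step 6: additional epochs of training and evaluation
        f"CALL tgn.train_and_eval({int(num_epochs)}) "
        "YIELD epoch_num, batch_num, batch_process_time, average_precision, batch_type "
        "RETURN epoch_num, batch_num, batch_process_time, average_precision, batch_type;"
    )


# Dry run: collect the statements instead of sending them to a database.
issued: List[str] = []
run_tgn_workflow(
    issued.append,
    "CALL tgn.set_params({batch_size: 2});",
    ["MERGE (n:Node {id: 1}) MERGE (m:Node {id: 6}) CREATE (n)-[:RELATION]->(m);"],
    ["MERGE (n:Node {id: 8}) MERGE (m:Node {id: 4}) CREATE (n)-[:RELATION]->(m);"],
    num_epochs=5,
)
print(len(issued))  # 6
```

Passing the executor in keeps the sketch independent of any particular client library.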

Procedures

This MAGE module is still in its early stage and is intended for learning activities only. Currently, you need to switch the TGN mode to eval manually; after the switch, incoming edges are used for evaluation only. If you would like the module to become production-ready, open a GitHub issue or drop us a comment on Discord.

set_params()

Each parameter has a default value. To change any of them, call the set_params() procedure with the new values.

Input:

  • subgraph: Graph (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the project() function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.

  • params: mgp.Map ➡ A dictionary containing the following parameters:

| Name | Type | Default | Description |
|---|---|---|---|
| learning_type | String | "self_supervised" | self_supervised or supervised, depending on whether you want to predict edges or node labels. |
| batch_size | Integer | 200 | Size of the batch processed by the TGN; the recommended size is 200. |
| num_of_layers | Integer | 2 | The number of layers of the graph neural network. 2 is the recommended value: with more layers, training takes longer and the gain in accuracy is insignificant. |
| layer_type | String | graph_attn | graph_attn or graph_sum layer type, as defined in the original paper. |
| memory_dimension | Integer | 100 | The dimension of the memory tensor of each node. |
| time_dimension | Integer | 100 | The dimension of the time vector from the time2vec paper. |
| num_edge_features | Integer | 50 | The number of edge features used from each edge. |
| num_node_features | Integer | 50 | The number of expected node features. |
| message_dimension | Integer | 100 | The dimension of the message; only used with mlp as the message function type, otherwise ignored. |
| num_neighbors | Integer | 15 | The number of sampled neighbors. |
| edge_message_function_type | String | identity | The message function type: identity for concatenation or mlp for projection. |
| message_aggregator_type | String | last | The message aggregator type: mean or last. |
| memory_updater_type | String | gru | The memory updater type: gru or rnn. |
| num_attention_heads | Integer | 1 | The number of attention heads, used if layer_type is graph_attn. |
| learning_rate | Float | 1e-4 | The learning rate for the Adam optimizer. |
| weight_decay | Float | 5e-5 | The weight decay used in the Adam optimizer. |
| device_type | String | cuda | The type of device to use for training: cuda or cpu. |
| node_features_property | String | features | The name of the property on nodes from which to read features. |
| edge_features_property | String | features | The name of the property on edges from which to read features. |
| node_label_property | String | label | The name of the property on nodes from which to read labels. |

Usage:

To set parameters, use the following query:

 CALL tgn.set_params({learning_type:'self_supervised', batch_size:200, num_of_layers:2,
                      layer_type:'graph_attn',memory_dimension:20, time_dimension:50,
                      num_edge_features:20, num_node_features:20, message_dimension:100,
                      num_neighbors:15, edge_message_function_type:'identity',
                      message_aggregator_type:'last', memory_updater_type:'gru', num_attention_heads:1});
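If you keep the parameters in application code, the query above can be generated from a dictionary. The helper below is a minimal sketch, not part of the module: to_cypher_map and set_params_query are hypothetical names, and only flat maps of strings and numbers are handled.

```python
from typing import Dict, Union

ParamValue = Union[str, int, float]


def to_cypher_map(params: Dict[str, ParamValue]) -> str:
    """Render a flat dict of str/int/float values as a Cypher map literal."""
    parts = []
    for key, value in params.items():
        if isinstance(value, str):
            # Escape backslashes and single quotes inside string values.
            escaped = value.replace("\\", "\\\\").replace("'", "\\'")
            parts.append(f"{key}: '{escaped}'")
        else:
            parts.append(f"{key}: {value}")
    return "{" + ", ".join(parts) + "}"


def set_params_query(overrides: Dict[str, ParamValue]) -> str:
    """Build the set_params() call for the given parameter overrides."""
    return f"CALL tgn.set_params({to_cypher_map(overrides)});"


print(set_params_query({"learning_type": "self_supervised", "batch_size": 200}))
# CALL tgn.set_params({learning_type: 'self_supervised', batch_size: 200});
```

Parameters left out of the dictionary keep the default values listed in the table.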

update()

The update() procedure scrapes data from edges, including edge_features and node_features if they exist, and fills up the batch. If the batch is ready, the TGN will process it and be ready to accept new incoming edges.

Input:

  • subgraph: Graph (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the project() function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.

  • edges: mgp.List[mgp.Edge] ➡ List of edges to preprocess (edges that arrive in a stream to Memgraph). If a batch is full, train or eval starts, depending on the mode.

Usage:

There are two ways to use the procedure.

The most convenient one is to create a trigger so that every time an edge is added to the graph, the trigger calls the procedure and makes an update.

CREATE TRIGGER create_embeddings ON --> CREATE BEFORE COMMIT
EXECUTE CALL tgn.update(createdEdges) RETURN 1;

The second option is to add all the edges and then call the algorithm with them:

MATCH (n)-[e]->(m)
WITH COLLECT(e) as edges
CALL tgn.update(edges) RETURN 1;

get()

The get() procedure retrieves calculated embeddings for each vertex.

Input:

  • subgraph: Graph (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the project() function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.

Output:

  • node: mgp.Vertex ➡ Vertex (node) in Memgraph.
  • embedding: mgp.List[float] ➡ Low-dimensional representation of node in the form of a graph embedding.

Usage:

To get calculated embeddings, run the following query:

CALL tgn.get()
YIELD node, embedding 
RETURN node, embedding;

set_eval()

Use the set_eval() procedure to change the TGN mode to eval.

Input:

  • subgraph: Graph (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the project() function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.

Output:

  • message: string ➡ Message that the TGN mode has been changed to eval.

Usage:

CALL tgn.set_eval() 
YIELD message
RETURN message;

get_results()

The get_results() procedure returns results for every batch you ran train or eval on, including average_precision and batch_process_time. Epoch count starts from 1.

Input:

  • subgraph: Graph (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the project() function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.

Output:

  • epoch_num: mgp.Number ➡ The number of train or eval epochs.
  • batch_num: mgp.Number ➡ The number of batches per train or eval epoch.
  • batch_process_time: mgp.Number ➡ Time needed to process a batch.
  • average_precision: mgp.Number ➡ Mean average precision on the current batch.
  • batch_type: string ➡ A string indicating whether train or eval is performed on the batch.

Usage:

To get results, use the following query:

CALL tgn.get_results() 
YIELD epoch_num, batch_num, batch_process_time, average_precision, batch_type 
RETURN epoch_num, batch_num, batch_process_time, average_precision, batch_type;

train_and_eval()

The purpose of the train_and_eval() procedure is to do additional training rounds on train edges and eval on evaluation edges.

Input:

  • subgraph: Graph (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the project() function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.

  • num_epochs: integer ➡ The number of additional training and evaluation epochs to run after the stream is done.

Output:

  • epoch_num: integer ➡ The epoch of the batch for which performance statistics will be returned.
  • batch_num: integer ➡ The number of the batch for which performance statistics will be returned.
  • batch_process_time: float ➡ Processing time in seconds for a batch.
  • average_precision: mgp.Number ➡ Mean average precision on the current batch.
  • batch_type: string ➡ Whether train or eval was performed on the batch.

Usage:

To do additional training rounds, use the following query:

CALL tgn.train_and_eval(10)
YIELD epoch_num, batch_num, batch_process_time, average_precision, batch_type
RETURN epoch_num, batch_num, batch_process_time, average_precision, batch_type;
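Because the procedure yields one row per batch, it is often easier to read the numbers after grouping them by epoch and batch type. The sketch below assumes the rows have already been fetched into Python tuples; the Row layout and the sample values are made up for illustration and are not output of the module.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

# (epoch_num, batch_num, average_precision, batch_type)
Row = Tuple[int, int, float, str]


def mean_precision_per_epoch(rows: Iterable[Row]) -> Dict[Tuple[int, str], float]:
    """Average the per-batch average_precision for each (epoch, batch_type) pair."""
    grouped: Dict[Tuple[int, str], List[float]] = defaultdict(list)
    for epoch_num, _batch_num, average_precision, batch_type in rows:
        grouped[(epoch_num, batch_type)].append(average_precision)
    return {key: sum(values) / len(values) for key, values in grouped.items()}


# Illustrative rows only.
rows = [
    (1, 1, 0.5, "Train"),
    (1, 2, 1.0, "Train"),
    (1, 3, 0.75, "Eval"),
]
print(mean_precision_per_epoch(rows))
# {(1, 'Train'): 0.75, (1, 'Eval'): 0.75}
```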

predict_link_score()

The predict_link_score() procedure returns the link prediction score for two vertices in the graph, provided you have been training the TGN for the link prediction task.

Input:

  • subgraph: Graph (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the project() function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.

  • src: mgp.Vertex ➡ Source vertex of the link prediction.

  • dest: mgp.Vertex ➡ Destination vertex of the link prediction.

Output:

  • prediction: mgp.Number ➡ A float between 0 and 1 giving the likelihood of a link between the source vertex and the destination vertex.

Usage:

To get a link prediction between two nodes, use the following query:

MATCH (n:User)
WITH n
LIMIT 1
MATCH (m:Item)
OPTIONAL MATCH (n)-[r]->(m)
WITH n, m, r
WHERE r IS NULL
CALL tgn.predict_link_score(n, m)
YIELD prediction
RETURN n, m, prediction;

Example

Database state

The database contains the following data:

Set model parameters

To set the model parameters, use the following query:

CALL tgn.set_params({learning_type:'self_supervised', batch_size:2, num_of_layers:1,
                      layer_type:'graph_attn',memory_dimension:100, time_dimension:100,
                      num_edge_features:20, num_node_features:20, message_dimension:100,
                      num_neighbors:10, edge_message_function_type:'identity',
                      message_aggregator_type:'last', memory_updater_type:'gru', num_attention_heads:1});

Set the trigger

To set the trigger, use the following query:

CREATE TRIGGER create_embeddings ON --> CREATE BEFORE COMMIT
EXECUTE CALL tgn.update(createdEdges) RETURN 1;

Load the training batch

Use the following queries to load the training batch:

MERGE (n:Node {id: 1}) MERGE (m:Node {id: 6}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 2}) MERGE (m:Node {id: 6}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 10}) MERGE (m:Node {id: 5}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 5}) MERGE (m:Node {id: 2}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 9}) MERGE (m:Node {id: 7}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 7}) MERGE (m:Node {id: 3}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 3}) MERGE (m:Node {id: 6}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 9}) MERGE (m:Node {id: 8}) CREATE (n)-[:RELATION]->(m);

Change the TGN mode

To change the TGN mode, use the following query:

CALL tgn.set_eval() 
YIELD message 
RETURN message;

Load the evaluation batch

To load the evaluation batch, use the following query:

MERGE (n:Node {id: 8}) MERGE (m:Node {id: 4}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 4}) MERGE (m:Node {id: 6}) CREATE (n)-[:RELATION]->(m);

Do additional training

To do additional training, use the following query:

CALL tgn.train_and_eval(5)
YIELD epoch_num, batch_num, batch_process_time, average_precision, batch_type
RETURN epoch_num, batch_num, batch_process_time, average_precision, batch_type;

Get results

To get the results for every batch you ran train or eval on, run the following query:

CALL tgn.get_results() YIELD  epoch_num, batch_num, average_precision, batch_process_time, batch_type
RETURN epoch_num, batch_num, average_precision, batch_type, batch_process_time;

Results:

+--------------------+--------------------+--------------------+--------------------+--------------------+
| epoch_num          | batch_num          | average_precision  | batch_type         | batch_process_time |
+--------------------+--------------------+--------------------+--------------------+--------------------+
| 1                  | 1                  | 0.5                | "Train"            | 0.05               |
| 1                  | 2                  | 0.42               | "Eval"             | 0.02               |
| 2                  | 1                  | 0.83               | "Train"            | 0.03               |
| 2                  | 2                  | 0.5                | "Train"            | 0.04               |
| 2                  | 3                  | 0.5                | "Train"            | 0.04               |
| 2                  | 4                  | 0.58               | "Train"            | 0.04               |
| 2                  | 5                  | 0.83               | "Eval"             | 0.02               |
| 3                  | 1                  | 0.5                | "Train"            | 0.03               |
| 3                  | 2                  | 0.75               | "Train"            | 0.03               |
| 3                  | 3                  | 0.83               | "Train"            | 0.03               |
| 3                  | 4                  | 1                  | "Train"            | 0.04               |
| 3                  | 5                  | 0.83               | "Eval"             | 0.02               |
| 4                  | 1                  | 0.5                | "Train"            | 0.03               |
| 4                  | 2                  | 0.58               | "Train"            | 0.03               |
| 4                  | 3                  | 1                  | "Train"            | 0.03               |
| 4                  | 4                  | 1                  | "Train"            | 0.04               |
| 4                  | 5                  | 1                  | "Eval"             | 0.02               |
| 5                  | 1                  | 0.83               | "Train"            | 0.03               |
| 5                  | 2                  | 0.58               | "Train"            | 0.03               |
| 5                  | 3                  | 1                  | "Train"            | 0.03               |
| 5                  | 4                  | 1                  | "Train"            | 0.03               |
| 5                  | 5                  | 0.83               | "Eval"             | 0.02               |
| 6                  | 1                  | 0.58               | "Train"            | 0.03               |
| 6                  | 2                  | 0.83               | "Train"            | 0.03               |
| 6                  | 3                  | 1                  | "Train"            | 0.03               |
| 6                  | 4                  | 1                  | "Train"            | 0.03               |
| 6                  | 5                  | 1                  | "Eval"             | 0.01               |
+--------------------+--------------------+--------------------+--------------------+--------------------+
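The average_precision column only takes the values 0.42, 0.5, 0.58, 0.75, 0.83 and 1. These are exactly the values that average precision can take when two positive and two negative examples are ranked, which would be consistent with batch_size: 2 and one negative edge per positive edge. That pairing is an assumption here, not something this page states. The sketch below uses the common definition of average precision (mean of precision at each rank where a positive appears); it is not taken from the module's code.

```python
from typing import Sequence


def average_precision(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Mean of precision@k over the ranks k at which a positive example appears."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits = 0
    precisions = []
    for rank, (_score, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precisions.append(hits / rank)
    if not precisions:
        raise ValueError("at least one positive label is required")
    return sum(precisions) / len(precisions)


# Two positive edges (label 1) and two negative edges (label 0).
print(round(average_precision([1, 0, 1, 0], [0.9, 0.8, 0.7, 0.1]), 2))  # 0.83
print(round(average_precision([0, 1, 1, 0], [0.9, 0.8, 0.7, 0.1]), 2))  # 0.58
```

With so few examples per batch the metric moves in large steps, so single-batch values in the table are best read as a trend across epochs rather than as precise scores.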