temporal_graph_networks
Temporal graph networks (TGNs) are a type of graph neural network (GNN) for dynamic graphs. In recent years, GNNs have become very popular due to their ability to perform a wide variety of machine learning tasks on graphs, such as link prediction and node classification. This rise started with graph convolutional networks (GCN), introduced by Kipf et al., followed by GraphSAGE, introduced by Hamilton et al. More recently, a method that brings the attention mechanism to graphs was presented: graph attention networks (GAT), by Veličković et al.
The last two methods offer great possibilities for inductive learning, but they haven't been specifically developed to handle the different events occurring on graphs, such as node feature updates, node deletion, edge deletion, and so on. These events happen regularly in real-world examples such as the Twitter network, where users update their profiles, delete their profiles, or unfollow other users.
In their work, Rossi et al. introduce temporal graph networks, an architecture for machine learning on streamed graphs, a rapidly growing ML use case.
The module consists of the following features, as introduced by Rossi et al.:
- Link prediction - train the TGN to predict new links/edges between nodes.
- Node classification - predict node labels from the graph structure and node/edge features.
- Graph attention layer and graph sum layer embedding calculation.
- mean and last message aggregators.
- mlp and identity (concatenation) message functions.
- gru and rnn memory updaters.
- Uniform temporal neighborhood sampler.
- Memory store and raw message store.
This means you can use the TGN to predict edges or perform node classification, with a graph attention layer or a graph sum layer, using either mean or last as the message aggregator, mlp or identity as the message function, and gru or rnn as the memory updater.
In total, this gives you 2✕2✕2✕2✕2 = 32 configurations to explore on your graph!
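As a quick sanity check of that count, the configuration grid can be enumerated programmatically. The option names below mirror the list above; the snippet itself is an illustrative sketch, not part of the module:

```python
import itertools

# Each axis corresponds to one binary choice listed above.
options = {
    "task": ["link_prediction", "node_classification"],
    "layer_type": ["graph_attn", "graph_sum"],
    "message_aggregator": ["mean", "last"],
    "message_function": ["mlp", "identity"],
    "memory_updater": ["gru", "rnn"],
}

# Cartesian product of all choices: 2^5 = 32 configurations.
configs = [dict(zip(options, combo)) for combo in itertools.product(*options.values())]
print(len(configs))  # 32
```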
What is not implemented in the module:
- Node update/deletion events, since they occur very rarely - although we have prepared the codebase to easily integrate them.
- Edge deletion events.
- Time projection embedding calculation and identity embedding calculation, since the authors mention they perform very poorly on all datasets (although it is trivial to add a new layer).
The query module is implemented using PyTorch. From the input (`mgp.Edge` and `mgp.Vertex` labels), edge features and node features are extracted. With a trigger set, the `update()` query module procedure parses all new edges and extracts the information the TGN needs for batch-by-batch processing.
The following piece of code shows what is extracted from edges while the batch is filling up. When the current batch reaches the predefined batch size (set with `set_params()`), the extracted information is forwarded to the TGN, which extends `torch.nn.Module`.
```python
import dataclasses
from typing import Dict

import numpy as np
import torch


@dataclasses.dataclass
class QueryModuleTGNBatch:
    current_batch_size: int
    sources: np.array
    destinations: np.array
    timestamps: np.array
    edge_idxs: np.array
    node_features: Dict[int, torch.Tensor]
    edge_features: Dict[int, torch.Tensor]
    batch_size: int
    labels: np.array
```
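To make the fill-then-flush behavior concrete, here is a minimal, self-contained sketch of the batching logic described above, using plain Python lists instead of the module's mgp/NumPy/torch types. The class and method names are illustrative, not part of the module:

```python
class BatchBuffer:
    """Collects incoming edges until `batch_size` is reached, then flushes."""

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.sources, self.destinations, self.timestamps = [], [], []

    def add_edge(self, source, destination, timestamp):
        # Extract the per-edge information and append it to the current batch.
        self.sources.append(source)
        self.destinations.append(destination)
        self.timestamps.append(timestamp)
        # Once the batch is full, hand it over for processing.
        if len(self.sources) >= self.batch_size:
            return self.flush()
        return None

    def flush(self):
        batch = (self.sources, self.destinations, self.timestamps)
        self.sources, self.destinations, self.timestamps = [], [], []
        return batch


buffer = BatchBuffer(batch_size=2)
assert buffer.add_edge(1, 6, 100) is None  # batch still filling up
batch = buffer.add_edge(2, 6, 101)         # batch full -> flushed
print(batch)  # ([1, 2], [6, 6], [100, 101])
```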
Processing one batch
```python
self._process_previous_batches()

graph_data = self._get_graph_data(
    np.concatenate([sources.copy(), destinations.copy()], dtype=int),
    np.concatenate([timestamps, timestamps]),
)

embeddings = self.tgn_net(graph_data)

# ... process negative edges in a similar way

self._process_current_batch(
    sources, destinations, node_features, edge_features, edge_idxs, timestamps
)
```
The `torch.nn.Module` is organized as follows:
- Processing previous batches - as in the research paper, this includes computing the messages collected for each node with a message function, aggregating the messages for each node with a message aggregator, and finally updating each node's memory with a memory updater.
- Afterward, creating a computation graph used by the graph attention layer or graph sum layer.
- The final step includes processing the current batch, creating new interaction or node events, and updating the raw message store with new events.
The process then repeats: new edges fill the next batch, the full batch is forwarded to the TGN, and so on.
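The difference between the two message aggregators mentioned above can be sketched in a few lines. The helper below is hypothetical and works on plain lists of `(timestamp, message_vector)` pairs rather than the module's real tensors:

```python
def aggregate(messages, mode):
    """Aggregate a node's pending messages.

    messages: list of (timestamp, message_vector) pairs.
    mode: "mean" averages all messages element-wise; "last" keeps only
          the most recent message, as in the TGN paper.
    """
    if mode == "last":
        # Pick the message with the highest timestamp.
        return max(messages, key=lambda tm: tm[0])[1]
    if mode == "mean":
        # Element-wise mean over all message vectors.
        vectors = [m for _, m in messages]
        return [sum(vals) / len(vals) for vals in zip(*vectors)]
    raise ValueError(f"unknown aggregator: {mode}")


msgs = [(1, [1.0, 2.0]), (2, [3.0, 4.0])]
print(aggregate(msgs, "mean"))  # [2.0, 3.0]
print(aggregate(msgs, "last"))  # [3.0, 4.0]
```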
| Trait | Value |
|---|---|
| Module type | module |
| Implementation | Python |
| Graph direction | directed |
| Edge weights | weighted/unweighted |
| Parallelism | sequential |
Usage
The basic TGN workflow is as follows:
- Set the parameters by calling the `set_params()` procedure.
- Set a trigger on the edge create event to call the `update()` procedure.
- Start loading your train queries.
- When the train queries are loaded, switch the TGN mode to `eval` by calling the `set_eval()` procedure.
- Load the evaluation queries.
- Do a few more epochs of training and evaluation to get the best results by calling the `train_and_eval()` procedure.
By calling the `set_eval()` procedure, you change the mode of the temporal graph network to `eval`. Any new edges that arrive will not be used to train the module, but to evaluate it.
Procedures
This MAGE module is still in its early stage. We intend to use it only for learning activities. In its current state, you need to manually switch the TGN mode to `eval`. After the switch, incoming edges are used for evaluation only. If you wish to make it production-ready, make sure to either open a GitHub issue or drop us a comment on Discord.
You can execute this algorithm on graph projections, subgraphs or portions of the graph.
set_params()
Each parameter has a defined default value. If you wish to change any of them, call the `set_params()` procedure with the new values.
Input:
- `subgraph: Graph` (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the `project()` function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.
- `params: mgp.Map` ➡ A dictionary containing the following parameters:
Name | Type | Default | Description |
---|---|---|---|
learning_type | String | "self_supervised" | self_supervised or supervised, depending on whether you want to predict edges or node labels. |
batch_size | Integer | 200 | Size of the batch to process by the TGN, recommended size 200. |
num_of_layers | Integer | 2 | The number of layers of the graph neural network. 2 is the optimal size: GNNs with more layers take longer to train, while the gain in accuracy is insignificant. |
layer_type | String | graph_attn | graph_attn or graph_sum layer type as defined in the original paper. |
memory_dimension | Integer | 100 | The dimension of memory tensor of each node. |
time_dimension | Integer | 100 | The dimension of time vector from the time2vec paper. |
num_edge_features | Integer | 50 | The number of edge features used from each edge. |
num_node_features | Integer | 50 | The number of expected node features. |
message_dimension | Integer | 100 | The dimension of the message, only used if you use MLP as the message function type, otherwise ignored. |
num_neighbors | Integer | 15 | The number of sampled neighbors. |
edge_message_function_type | String | identity | The message function type, identity for concatenation or mlp for projection. |
message_aggregator_type | String | last | The message aggregator type, mean or last . |
memory_updater_type | String | gru | The memory updater type, gru or rnn . |
num_attention_heads | Integer | 1 | The number of attention heads used if you define graph_attn as layer type. |
learning_rate | Float | 1e-4 | The learning rate for adam optimizer. |
weight_decay | Float | 5e-5 | The weight decay used in adam optimizer. |
device_type | String | cuda | The type of device you want to use for training - cuda or cpu . |
node_features_property | String | features | The name of the features property on nodes from which to read features. |
edge_features_property | String | features | The name of the features property on edges from which to read features. |
node_label_property | String | label | The name of the label property on nodes from which to read labels. |
Usage:
To set parameters, use the following query:
```cypher
CALL tgn.set_params({learning_type:'self_supervised', batch_size:200, num_of_layers:2,
  layer_type:'graph_attn', memory_dimension:20, time_dimension:50,
  num_edge_features:20, num_node_features:20, message_dimension:100,
  num_neighbors:15, edge_message_function_type:'identity',
  message_aggregator_type:'last', memory_updater_type:'gru', num_attention_heads:1});
```
update()
The `update()` procedure scrapes data from edges, including `edge_features` and `node_features` if they exist, and fills up the batch. When the batch is full, the TGN processes it and is then ready to accept new incoming edges.
Input:
- `subgraph: Graph` (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the `project()` function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.
- `edges: mgp.List[mgp.Edges]` ➡ A list of edges to preprocess (that arrive in a stream to Memgraph). If a batch is full, `train` or `eval` starts, depending on the mode.
Usage:
There are two options for using the procedure.
The most convenient one is to create a trigger so that every time an edge is added to the graph, the trigger calls the procedure and makes an update.
```cypher
CREATE TRIGGER create_embeddings ON --> CREATE BEFORE COMMIT
EXECUTE CALL tgn.update(createdEdges) RETURN 1;
```
The second option is to add all the edges and then call the algorithm with them:
```cypher
MATCH (n)-[e]->(m)
WITH COLLECT(e) AS edges
CALL tgn.update(edges) RETURN 1;
```
get()
The `get()` procedure retrieves the calculated embeddings for each vertex.
Input:
- `subgraph: Graph` (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the `project()` function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.
Output:
- `node: mgp.Vertex` ➡ Vertex (node) in Memgraph.
- `embedding: mgp.List[float]` ➡ Low-dimensional representation of the node in the form of a graph embedding.
Usage:
To get calculated embeddings, run the following query:
```cypher
CALL tgn.get()
YIELD node, embedding
RETURN node, embedding;
```
set_eval()
Use the `set_eval()` procedure to change the TGN mode to `eval`.
Input:
- `subgraph: Graph` (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the `project()` function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.
Output:
- `message: string` ➡ Message that the TGN mode has been changed to `eval`.
Usage:
```cypher
CALL tgn.set_eval()
YIELD message
RETURN message;
```
get_results()
The `get_results()` procedure returns the results for every batch you ran `train` or `eval` on, as well as `average_precision` and `batch_process_time`. The epoch count starts from 1.
Input:
- `subgraph: Graph` (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the `project()` function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.
Output:
- `epoch_num: mgp.Number` ➡ The number of `train` or `eval` epochs.
- `batch_num: mgp.Number` ➡ The number of batches per `train` or `eval` epoch.
- `batch_process_time: mgp.Number` ➡ Time needed to process a batch.
- `average_precision: mgp.Number` ➡ Mean average precision on the current batch.
- `batch_type: string` ➡ A string indicating whether `train` or `eval` was performed on the batch.
Usage:
To get results, use the following query:
```cypher
CALL tgn.get_results()
YIELD epoch_num, batch_num, batch_process_time, average_precision, batch_type
RETURN epoch_num, batch_num, batch_process_time, average_precision, batch_type;
```
train_and_eval()
The purpose of the `train_and_eval()` procedure is to do additional training rounds on `train` edges and evaluation on the evaluation edges.
Input:
- `subgraph: Graph` (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the `project()` function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.
- `num_epochs: integer` ➡ The number of additional epochs of training and evaluation to perform after the stream is done.
Output:
- `epoch_num: integer` ➡ The epoch of the batch for which performance statistics will be returned.
- `batch_num: integer` ➡ The number of the batch for which performance statistics will be returned.
- `batch_process_time: float` ➡ Processing time in seconds for a batch.
- `average_precision: mgp.Number` ➡ Mean average precision on the current batch.
- `batch_type: string` ➡ Whether `train` or `eval` was performed on the batch.
Usage:
To do additional training rounds, use the following query:
```cypher
CALL tgn.train_and_eval(10)
YIELD epoch_num, batch_num, batch_process_time, average_precision, batch_type
RETURN epoch_num, batch_num, batch_process_time, average_precision, batch_type;
```
predict_link_score()
The purpose of the `predict_link_score()` procedure is to get the link prediction score for two vertices in the graph, provided the TGN has been trained for the link prediction task.
Input:
- `subgraph: Graph` (OPTIONAL) ➡ A specific subgraph, which is an object of type Graph returned by the `project()` function, on which the algorithm is run. If subgraph is not specified, the algorithm is computed on the entire graph by default.
- `src: mgp.Vertex` ➡ Source vertex of the link prediction.
- `dest: mgp.Vertex` ➡ Destination vertex of the link prediction.
Output:
- `prediction: mgp.Number` ➡ A float between 0 and 1: the likelihood of a link between the `source` vertex and the `destination` vertex.
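Conceptually, the score is produced by decoding the two node embeddings into a probability. A minimal sketch, assuming a simple dot-product decoder followed by a sigmoid; the module's actual decoder is learned during training, so treat this as illustrative only:

```python
import math


def link_score(src_embedding, dest_embedding):
    """Map a pair of node embeddings to a probability in (0, 1).

    Illustrative dot-product decoder + sigmoid; the module itself uses
    a learned decoder over the embeddings.
    """
    dot = sum(s * d for s, d in zip(src_embedding, dest_embedding))
    return 1.0 / (1.0 + math.exp(-dot))  # sigmoid squashes to (0, 1)


print(link_score([0.5, 1.0], [1.0, 0.0]))  # ≈ 0.622
```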
Usage:
To get a link prediction between two nodes, use the following query:
```cypher
MATCH (n:User)
WITH n
LIMIT 1
MATCH (m:Item)
OPTIONAL MATCH (n)-[r]->(m)
WITH n, m, r
WHERE r IS NULL
CALL tgn.predict_link_score(n, m)
YIELD prediction
RETURN n, m, prediction;
```
Example
Database state
The database contains the following data:
Set model parameters
To set the model parameters, use the following query:
```cypher
CALL tgn.set_params({learning_type:'self_supervised', batch_size:2, num_of_layers:1,
  layer_type:'graph_attn', memory_dimension:100, time_dimension:100,
  num_edge_features:20, num_node_features:20, message_dimension:100,
  num_neighbors:10, edge_message_function_type:'identity',
  message_aggregator_type:'last', memory_updater_type:'gru', num_attention_heads:1});
```
Set the trigger
To set the trigger, use the following query:
```cypher
CREATE TRIGGER create_embeddings ON --> CREATE BEFORE COMMIT
EXECUTE CALL tgn.update(createdEdges) RETURN 1;
```
Load the training batch
Use the following queries to load the training batch:
```cypher
MERGE (n:Node {id: 1}) MERGE (m:Node {id: 6}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 2}) MERGE (m:Node {id: 6}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 10}) MERGE (m:Node {id: 5}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 5}) MERGE (m:Node {id: 2}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 9}) MERGE (m:Node {id: 7}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 7}) MERGE (m:Node {id: 3}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 3}) MERGE (m:Node {id: 6}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 9}) MERGE (m:Node {id: 8}) CREATE (n)-[:RELATION]->(m);
```
Change the TGN mode
To change the TGN mode, use the following query:
```cypher
CALL tgn.set_eval()
YIELD message
RETURN message;
```
Load the evaluation batch
To load the evaluation batch, use the following query:
```cypher
MERGE (n:Node {id: 8}) MERGE (m:Node {id: 4}) CREATE (n)-[:RELATION]->(m);
MERGE (n:Node {id: 4}) MERGE (m:Node {id: 6}) CREATE (n)-[:RELATION]->(m);
```
Do additional training
To do additional training, use the following query:
```cypher
CALL tgn.train_and_eval(5)
YIELD epoch_num, batch_num, batch_process_time, average_precision, batch_type
RETURN epoch_num, batch_num, batch_process_time, average_precision, batch_type;
```
Get results
To get the results for every batch you ran `train` or `eval` on, run the following query:
```cypher
CALL tgn.get_results()
YIELD epoch_num, batch_num, average_precision, batch_process_time, batch_type
RETURN epoch_num, batch_num, average_precision, batch_type, batch_process_time;
```
Results:
```plaintext
+--------------------+--------------------+--------------------+--------------------+--------------------+
| epoch_num          | batch_num          | average_precision  | batch_type         | batch_process_time |
+--------------------+--------------------+--------------------+--------------------+--------------------+
| 1                  | 1                  | 0.5                | "Train"            | 0.05               |
| 1                  | 2                  | 0.42               | "Eval"             | 0.02               |
| 2                  | 1                  | 0.83               | "Train"            | 0.03               |
| 2                  | 2                  | 0.5                | "Train"            | 0.04               |
| 2                  | 3                  | 0.5                | "Train"            | 0.04               |
| 2                  | 4                  | 0.58               | "Train"            | 0.04               |
| 2                  | 5                  | 0.83               | "Eval"             | 0.02               |
| 3                  | 1                  | 0.5                | "Train"            | 0.03               |
| 3                  | 2                  | 0.75               | "Train"            | 0.03               |
| 3                  | 3                  | 0.83               | "Train"            | 0.03               |
| 3                  | 4                  | 1                  | "Train"            | 0.04               |
| 3                  | 5                  | 0.83               | "Eval"             | 0.02               |
| 4                  | 1                  | 0.5                | "Train"            | 0.03               |
| 4                  | 2                  | 0.58               | "Train"            | 0.03               |
| 4                  | 3                  | 1                  | "Train"            | 0.03               |
| 4                  | 4                  | 1                  | "Train"            | 0.04               |
| 4                  | 5                  | 1                  | "Eval"             | 0.02               |
| 5                  | 1                  | 0.83               | "Train"            | 0.03               |
| 5                  | 2                  | 0.58               | "Train"            | 0.03               |
| 5                  | 3                  | 1                  | "Train"            | 0.03               |
| 5                  | 4                  | 1                  | "Train"            | 0.03               |
| 5                  | 5                  | 0.83               | "Eval"             | 0.02               |
| 6                  | 1                  | 0.58               | "Train"            | 0.03               |
| 6                  | 2                  | 0.83               | "Train"            | 0.03               |
| 6                  | 3                  | 1                  | "Train"            | 0.03               |
| 6                  | 4                  | 1                  | "Train"            | 0.03               |
| 6                  | 5                  | 1                  | "Eval"             | 0.01               |
+--------------------+--------------------+--------------------+--------------------+--------------------+
```