A Hyperparametrization Is All You Need - Building a Recommendation System for Telecommunication Packages Using Graph Neural Networks


December 5, 2022
Andi Skrgat

Hello folks! We are happy to announce that Memgraph has prepared another magic spell to continue the graph machine learning story. After successfully entering the world of graph machine learning with Temporal Graph Networks, it is time to bring more great graph neural network (GNN) models to Memgraph.

Read on to find out:

  • Why do GNNs appear so often in the headlines of research papers in 2022?
  • How to build a link prediction module?
  • How to build a recommendation system for telecommunication packages on top of link prediction?

Let’s dive into the magic world of GNNs!

The incredible world of GNNs

Since their introduction, graph neural networks have proven extremely useful in numerous domains and applications. They are part of a much broader area called geometric deep learning, which deals not only with graphs but also with manifolds.

Their expressiveness and ability to encode complicated relationships give the graph world what convolutional neural networks (CNNs) gave the computer vision and NLP world. One could then ask, why do we even need GNNs? The simplest answer is that, although text and images can be thought of as graphs, they have a regular, grid-like structure with much more spatial locality than general graphs, and highly regularized neural networks like CNNs are built to exploit exactly that structure. Arbitrary graphs have no fixed node ordering or neighborhood size, so CNNs aren't suitable for them.

Hmm, this all seems a little abstract, so what is the goal of graph neural networks? The goal is to compute node representations automatically and efficiently by iteratively aggregating the representations of a node's neighbors and combining them with the node's own representation from the previous iteration. A good representation is useful and contains discriminative information captured from several explanatory factors. General properties such as smoothness, disentanglement and linearity are also very desirable. This is the basis of many great models, such as GraphSAGE and GAT.

GraphSAGE is an incredible algorithm developed at Stanford that enables models to scale up to giant graphs. Its strength lies in the ability to efficiently sample and then aggregate information from node neighborhoods.


Figure 1. GraphSAGE modus operandi.
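To make the sampling step concrete, here is a minimal pure-Python sketch of fixed-size neighborhood sampling in the spirit of GraphSAGE. It assumes the graph is a plain adjacency dictionary and that neighbors are drawn uniformly without replacement. The function names and the `fanouts` parameter are illustrative only and are not part of the MAGE module or DGL.

```python
import random

def sample_neighbors(adj, node, fanout, rng):
    """Uniformly sample at most `fanout` neighbors of `node` without replacement."""
    neighbors = adj.get(node, [])
    if len(neighbors) <= fanout:
        return list(neighbors)
    return rng.sample(neighbors, fanout)

def sample_computation_graph(adj, seeds, fanouts, rng):
    """Sample one neighborhood per hop, starting from the seed nodes.

    Returns one dict per hop that maps each node in the current frontier
    to its sampled neighbors.
    """
    hops = []
    frontier = set(seeds)
    for fanout in fanouts:
        sampled = {v: sample_neighbors(adj, v, fanout, rng) for v in sorted(frontier)}
        hops.append(sampled)
        # The next hop needs the current nodes plus every neighbor sampled now.
        frontier = frontier | {u for nbrs in sampled.values() for u in nbrs}
    return hops

adj = {"a": ["b", "c", "d", "e"], "b": ["a"], "c": ["a", "d"], "d": ["a", "c"], "e": ["a"]}
hops = sample_computation_graph(adj, ["a"], fanouts=[2, 2], rng=random.Random(42))
for hop, sampled in enumerate(hops, start=1):
    print(hop, sampled)
```

Because every node keeps at most `fanout` neighbors per hop, the cost of computing one embedding is bounded by the product of the fanouts, regardless of how large the graph or a node's degree gets. This is what lets GraphSAGE scale to giant graphs.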

The model starts by aggregating representations from neighbors:

$$h^{k}_{N(v)} \leftarrow \text{AGGREGATE}_k\left(\left\{h^{k-1}_{u}, \forall u \in N(v)\right\}\right)$$

Then, it combines them with the old node representation:

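$$h^{k}_{v} \leftarrow \sigma\left(W^{k} \cdot \text{CONCAT}\left(h^{k-1}_{v}, h^{k}_{N(v)}\right)\right)$$

where $W^{k}$ is the trainable weight matrix of iteration $k$ and $\sigma$ is a nonlinear activation function.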

Finally, it normalizes them to get a more stable learning process:

$$h^{k}_{v} \leftarrow \frac{h^{k}_{v}}{\lVert h^{k}_{v} \rVert_2}, \quad \forall v \in V$$

Here, V is the set of nodes, N(v) is the neighborhood of node v, the subscript denotes the node and the superscript denotes the iteration.
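The three steps above (aggregate, combine, normalize) can be traced in a small pure-Python sketch of one GraphSAGE iteration. It assumes a mean aggregator and ReLU as the nonlinearity σ, which are only one possible choice, and uses plain lists instead of tensors. It is an illustration, not the DGL or MAGE implementation.

```python
import math

def mean_aggregate(vectors):
    """AGGREGATE_k: element-wise mean of the neighbor representations."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def l2_normalize(x):
    """Scale x to unit L2 norm (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(v * v for v in x))
    return [v / norm for v in x] if norm > 0 else list(x)

def graphsage_layer(h, adj, W):
    """One GraphSAGE iteration with a mean aggregator and ReLU as sigma.

    h:   dict node -> representation from iteration k-1
    adj: dict node -> list of (sampled) neighbors
    W:   weight matrix W^k with shape (out_dim, 2 * in_dim)
    """
    h_next = {}
    for v, h_v in h.items():
        neighbors = adj.get(v, [])
        # 1) Aggregate the neighbors' previous representations (zeros if isolated).
        h_nv = mean_aggregate([h[u] for u in neighbors]) if neighbors else [0.0] * len(h_v)
        # 2) Combine: sigma(W^k . CONCAT(h_v^{k-1}, h_N(v)^k)); list + list concatenates.
        combined = [max(0.0, z) for z in matvec(W, h_v + h_nv)]
        # 3) Normalize to unit length.
        h_next[v] = l2_normalize(combined)
    return h_next

h0 = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
adj = {"a": ["b", "c"], "b": ["a"], "c": ["a"]}
W = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0]]
h1 = graphsage_layer(h0, adj, W)
print(h1["b"])  # both entries are about 0.7071
```

Stacking k such layers lets every node see information from up to k hops away, and the trainable parameters are the weight matrices and not one embedding per node.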

Graph Attention Network (GAT) achieves strong performance by using an attention mechanism that lets every node in the graph weight the importance of each of its neighbors. To achieve that, a shared weight matrix W is applied as a linear transformation to the representation of every node. With that in mind, attention coefficients can be computed for every pair of neighboring nodes using the attention function a:

$$e_{ij} = a\left(W h_{i}, W h_{j}\right)$$

where the coefficient $e_{ij}$ indicates the importance of node j to node i. To make the coefficients comparable across different nodes, one needs to normalize them over all neighbors j of node i using the softmax function:

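$$\alpha_{ij} = \text{softmax}_j\left(e_{ij}\right) = \frac{\exp\left(e_{ij}\right)}{\sum_{k \in N(i)} \exp\left(e_{ik}\right)}$$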

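After normalizing the coefficients, the new representation of node i is obtained by summing the transformed neighbor representations $W h_j$ weighted by $\alpha_{ij}$ and applying a nonlinearity, similarly to the aggregation step in the GraphSAGE case.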
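Here is a minimal single-head sketch of that attention step in pure Python. Following the GAT paper, the attention function is assumed to be a single-layer feedforward network, $e_{ij} = \text{LeakyReLU}(a^\top [W h_i \,\Vert\, W h_j])$ with slope 0.2, and each node also attends to itself. The final nonlinearity and multi-head attention are left out, and the names are illustrative, not an API of any library.

```python
import math

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def leaky_relu(z, slope=0.2):
    return z if z > 0 else slope * z

def gat_layer(h, adj, W, a):
    """Single-head GAT layer; the final nonlinearity sigma is omitted for brevity.

    h:   dict node -> input features h_i
    adj: dict node -> list of neighbors
    W:   shared weight matrix
    a:   attention vector of length 2 * out_dim
    Returns (new representations, attention coefficients alpha_ij).
    """
    Wh = {v: matvec(W, x) for v, x in h.items()}
    out, alphas = {}, {}
    for i in h:
        # As in the GAT paper, node i also attends to itself.
        nbrs = [i] + [j for j in adj.get(i, []) if j != i]
        # e_ij = LeakyReLU(a^T [W h_i || W h_j])
        e = [leaky_relu(sum(ak * zk for ak, zk in zip(a, Wh[i] + Wh[j]))) for j in nbrs]
        # alpha_ij = softmax over j; subtracting the max keeps exp() stable.
        m = max(e)
        exps = [math.exp(x - m) for x in e]
        total = sum(exps)
        alphas[i] = {j: x / total for j, x in zip(nbrs, exps)}
        # Sum of the transformed neighbor features weighted by alpha_ij.
        out[i] = [sum(alphas[i][j] * Wh[j][d] for j in nbrs) for d in range(len(Wh[i]))]
    return out, alphas

h = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
adj = {"a": ["b", "c"], "b": ["a"], "c": ["a"]}
W = [[1.0, 0.0], [0.0, 1.0]]
out, alphas = gat_layer(h, adj, W, a=[0.5, -0.5, 1.0, 0.2])
print(alphas["a"])
```

With an all-zero attention vector every neighbor gets the same weight, and the layer reduces to a plain mean over the neighborhood. The learned vector `a` is what lets GAT prefer some neighbors over others.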

Both models, as well as the graph creation, were implemented using DGL (Deep Graph Library). DGL is an open-source project hosted on GitHub that is constantly improving thanks to numerous active developers. One of the greatest things this library offers is its framework-agnostic setup, which allows you to use whatever backend you want once the graph has been created. Whether you want to use PyTorch, MXNet or TensorFlow, it's your call 🖖.

Why is link prediction so special?

Okay, it was mentioned above that GNNs are extremely good at creating node embeddings by aggregating information from local neighborhoods, but how can one create a model that predicts links between nodes? Well, one more step is needed in the pipeline: an edge predictor. The link prediction module supports two of them. DotPredictor, as its name suggests, computes the edge score as the dot product of the source and destination node embeddings. MLPPredictor is a trainable neural network that also takes the source and destination node embeddings as input and produces an edge score as the output of its last layer. Voilà, now if you add a simple classifier on top of the edge scores, you can start playing and building amazing products 🔨.


Figure 2. Combining node embeddings into an edge score using DotPredictor.
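The two edge predictors and the classifier on top of them can be sketched in a few lines of pure Python. This is a hypothetical stand-in that only mirrors the idea behind DotPredictor and MLPPredictor. It assumes a one-hidden-layer MLP over the concatenated embeddings, a sigmoid and a 0.5 threshold as the classifier, and it is not the MAGE source code.

```python
import math

def dot_predictor(h_src, h_dst):
    """Edge score as the dot product of source and destination embeddings."""
    return sum(x * y for x, y in zip(h_src, h_dst))

def mlp_predictor(h_src, h_dst, W1, b1, w2, b2):
    """Edge score from a one-hidden-layer MLP over the concatenated embeddings."""
    x = h_src + h_dst
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(W1, b1)]
    return sum(w * z for w, z in zip(w2, hidden)) + b2

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def edge_exists(score, threshold=0.5):
    """Simple classifier on top of an edge score."""
    return sigmoid(score) >= threshold

customer = [1.0, 2.0]
service = [3.0, 4.0]
score = dot_predictor(customer, service)  # 1*3 + 2*4 = 11.0
print(score, edge_exists(score))  # 11.0 True
```

The dot product has no parameters of its own, so all learning happens in the GNN layers. The MLP variant adds trainable weights that can model interactions the dot product can't express, at the cost of more parameters to tune.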

In graph machine learning, the terms transductive and inductive are often mentioned, but many people struggle to understand the difference. Put as simply as possible, a transductive model can't generalize to nodes it hasn't seen during training, while an inductive model can, and at inference time you can even plug in a completely new graph.

The new MAGE module does allow you to generalize to unseen nodes, but not to unseen graphs, because of the specific use case we covered. If you need this feature, feel free to open an issue on GitHub or write a PR to contribute to Memgraph, why not?

The world of telecommunications

After a fairly comprehensive theoretical introduction, let's get to our domain and find out why graphs are a great tool for modeling a use case in telecommunications.

In this tutorial, we will use a dataset IBM uses for its solutions. It consists of 7042 customers that interact with services (packages) for which they have a contract.


Figure 3. The graph model and the list of all supported services.

Since our goal is to create a system that will recommend packages to new users, it makes sense to connect customers with packages to capture this interaction. The dataset also provides personal information about customers, such as their age, gender, location, marital status and whether any dependents (children, parents) live with them. To extract more information, customers who share both age and city have been connected. Note that making such connections enhances the learning process: customer representations can now also be updated from neighboring customers, whereas before they could only be updated from the packages the customers are subscribed to.
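The modeling rule can be sketched in plain Python before touching the database. The sketch produces the two edge types that the Cypher queries later in this post create: customer-to-service subscriptions, and connections between customers with the same age and city. The record layout and the `build_edges` helper are hypothetical.

```python
from collections import defaultdict
from itertools import combinations

def build_edges(customers):
    """Return (SUBSCRIBES_TO edges, CONNECTS_TO edges) for a list of customer records."""
    # Customer -> service edges capture the subscriptions.
    subscribes_to = [(c["id"], s) for c in customers for s in c["services"]]
    # Customers with the same (age, city) pair are connected to each other.
    groups = defaultdict(list)
    for c in customers:
        groups[(c["age"], c["city"])].append(c["id"])
    connects_to = [pair for ids in groups.values() for pair in combinations(sorted(ids), 2)]
    return subscribes_to, connects_to

customers = [  # hypothetical records
    {"id": "c1", "age": 30, "city": "Los Angeles", "services": ["Phone"]},
    {"id": "c2", "age": 30, "city": "Los Angeles", "services": ["Internet", "Phone"]},
    {"id": "c3", "age": 45, "city": "Los Angeles", "services": []},
]
subscribes_to, connects_to = build_edges(customers)
print(subscribes_to)  # [('c1', 'Phone'), ('c2', 'Internet'), ('c2', 'Phone')]
print(connects_to)    # [('c1', 'c2')]
```

Grouping by the (age, city) key first means only customers inside the same group are paired, not every customer with every other one.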

Building a recommendation system

Once we have created a graph, we can extract an abundance of information because of the complicated relationships we were able to encode. Before diving deeper into the implementation, install the Memgraph platform and make sure to check out how the LOAD CSV clause works.

Getting the data

First, download the CSV file that specifies service features. Then, use Memgraph Lab to import the data from the downloaded CSV file.

LOAD CSV FROM "./services.csv" WITH HEADER AS row
CREATE (s:Service {name: row["service_name"], features: row["features"]})
RETURN s;

Since the dataset doesn't come with meaningful features, the feature vectors are random. Although it might seem counterintuitive, it is interesting that by setting features to be random vectors, GNNs become universal approximators, which is very cool!

Now, using the second CSV file, we will connect customers to packages they are subscribed to:

LOAD CSV FROM "./interactions.csv" WITH HEADER AS row
CREATE (c:Customer {id: row["Customer ID"], gender: row["Gender"], age: row["Age"], married: row["Married"], dependents: row["Dependents"], number_of_dependents: row["Number of Dependents"], city: row["City"], features: row["Features"], services: row["Services"]})
WITH c
MATCH (s:Service)
WHERE c.services CONTAINS s.name
CREATE (c)-[r:SUBSCRIBES_TO]->(s)
RETURN c, r, s;

Before connecting customers, it is desirable to create indexes to speed up the loading process:

CREATE INDEX ON :Customer(id);
CREATE INDEX ON :Customer(age);
CREATE INDEX ON :Customer(city);

And now, for reasons mentioned above, customers can be connected based on their age and city:

MATCH (c1:Customer)
MATCH (c2:Customer)
WHERE (c1.id != c2.id) AND ((c1.age = c2.age) AND (c1.city = c2.city))
MERGE (c1)-[e:CONNECTS_TO]-(c2)
RETURN c1, e, c2;

Training the model

The data has been loaded, so the only thing that remains to be done before actually training the model is setting the training parameters. This is an important step because, as you might already know, every problem has its own set of best parameters, and as the no free lunch theorem states, the best parameters for one use case aren't necessarily the best for another. Another important thing to consider is training speed: more layers result in a slower training process. For graph neural networks, three or four layers should be enough to handle almost any use case.

There are many parameters you can set in the link prediction module, so take a look at the docs page for more details. Here's a brief overview of the parameters you need to modify when developing your solution for this use case:

  • target_relation - A heterogeneous graph has multiple edge types, so you need to specify which edge type you want to train your model on. In the case of a homogeneous graph, target_relation will be inferred automatically.
  • node_features_property - Property name where node features are saved.
  • add_reverse_edges - To be as efficient as possible, store a directed graph in MemgraphDB and set this flag to true, so that a reverse edge is added for each existing edge and the model can learn better representations.
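Conceptually, reverse edges matter because in typical message-passing implementations, such as DGL's, messages flow from the source to the destination of an edge. With only Customer→Service edges, services would receive messages from customers but customers would never hear back. The tiny sketch below only illustrates what adding reverse edges means for an edge list. The helper is hypothetical and is not how the module implements the `add_reverse_edges` flag.

```python
def add_reverse_edges(edges):
    """Return `edges` plus a reversed copy of every edge that isn't already present."""
    seen = set(edges)
    result = list(edges)
    for src, dst in edges:
        if (dst, src) not in seen:
            seen.add((dst, src))
            result.append((dst, src))
    return result

directed = [("c1", "Phone"), ("c2", "Phone")]
print(add_reverse_edges(directed))
# [('c1', 'Phone'), ('c2', 'Phone'), ('Phone', 'c1'), ('Phone', 'c2')]
```

Storing only one direction in the database and adding the reverse at training time keeps the stored graph half the size while still letting information flow both ways during training.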

We can use the following query to set parameters:

CALL link_prediction.set_model_parameters({target_relation: "SUBSCRIBES_TO", node_features_property: "features", add_reverse_edges: true}) YIELD * RETURN *;

And for training use this query:

CALL link_prediction.train() YIELD * RETURN *;

A great way to inspect a machine learning model is to visualize the learning process:


Figure 4. Performance visualization.


Figure 5. Metrics summary

Explaining machine learning results usually starts with computing a confusion matrix. In our recommendation system, a false negative occurs when the model states there isn't a relationship between two nodes when in fact there is. On the other hand, when the model predicts there is an edge but in reality there isn't one, false positive, we salute you ✋. True positives and true negatives occur when the model correctly predicts the existence or non-existence of an edge. Armed with all the confusion matrix combatants, we are ready to calculate accuracy, precision and recall.
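For reference, here is a small pure-Python sketch that counts the four confusion matrix cells for binary edge labels and derives accuracy, precision, recall and F1 from them. The toy labels are made up for illustration and are not taken from the training run described here.

```python
def confusion_counts(y_true, y_pred):
    """Count TP, TN, FP, FN for binary edge labels (1 = edge exists)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t and p)
    tn = sum(1 for t, p in zip(y_true, y_pred) if not t and not p)
    fp = sum(1 for t, p in zip(y_true, y_pred) if not t and p)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t and not p)
    return tp, tn, fp, fn

def classification_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F1 derived from the confusion matrix."""
    tp, tn, fp, fn = confusion_counts(y_true, y_pred)
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# Toy labels: 1 = the customer-service edge exists, 0 = it doesn't.
y_true = [1, 1, 1, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0]
print(confusion_counts(y_true, y_pred))  # (2, 2, 1, 1)
print(classification_metrics(y_true, y_pred))
```

Precision penalizes false positives (recommending a package the customer doesn't want), while recall penalizes false negatives (missing one they do want). F1 balances the two.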

Accuracy and F1 score of 90% show that the GNN successfully found patterns in the dataset. There is still room for improvement and many things yet to be tried, but considering that the dataset isn't meant to be used for recommendations and that, lacking real features, we initialized them randomly, we can conclude that GNNs are really a heck of a deal 🥇.

Just as a side note, be aware that we obtained these results after 5 hours of training (without a GPU) and trying more than a hundred hyperparameter combinations with random search. Hence the “A Hyperparametrization Is All You Need” title 🧮.
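Random search itself is simple: draw each hyperparameter at random, train, score, and keep the best configuration. The sketch below assumes a hypothetical search space and replaces the real "train, then read a validation metric" step with a toy `toy_evaluate` function. The parameter names here are placeholders and not the module's actual parameter names, which are listed in its docs.

```python
import random

def random_search(search_space, evaluate, n_trials, seed=0):
    """Sample `n_trials` random configurations and keep the best-scoring one."""
    rng = random.Random(seed)
    best_params, best_score = None, float("-inf")
    for _ in range(n_trials):
        # Draw each hyperparameter independently and uniformly from its candidates.
        params = {name: rng.choice(values) for name, values in search_space.items()}
        score = evaluate(params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score

# Hypothetical search space; real parameter names come from the module's docs.
search_space = {
    "learning_rate": [0.1, 0.01, 0.001],
    "num_layers": [1, 2, 3, 4],
}

def toy_evaluate(params):
    """Stand-in for 'train, then return a validation metric'; peaks at lr=0.01, 3 layers."""
    lr_penalty = 0 if params["learning_rate"] == 0.01 else 1
    return -abs(params["num_layers"] - 3) - lr_penalty

best_params, best_score = random_search(search_space, toy_evaluate, n_trials=120)
print(best_params, best_score)
```

In a real run, each call to `evaluate` is a full training job, which is why a hundred trials can easily add up to hours without a GPU.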

Create recommendations

With the trained model it’s extremely easy to get new recommendations for users:

MATCH (n:Customer {id: "1658-BYGOY"})
MATCH (s:Service)
WITH collect(s) AS services, n
CALL link_prediction.recommended_vertex(n, services, 6)
YIELD score, recommendation
RETURN score, recommendation;

The model recommended the following packages to the user:


Figure 6. Recommended packages.

With this query, the user also gets recommendation-specific metrics like Precision@k, Recall@k, F1@k and Average precision.


Figure 7. Recommendation metrics.
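To see how such ranking metrics are typically computed, here is a pure-Python sketch that ranks candidate services by score, keeps the top k, and evaluates Precision@k, Recall@k, F1@k and Average Precision@k against the services a customer actually uses. The service names and scores are made up. The AP denominator `min(|relevant|, k)` is one common convention and an assumption here, since the module may define it differently.

```python
def top_k(scores, k):
    """Return the k candidates with the highest scores (ties broken by name)."""
    ranked = sorted(scores.items(), key=lambda item: (-item[1], item[0]))
    return [candidate for candidate, _ in ranked[:k]]

def precision_at_k(recommended, relevant, k):
    hits = sum(1 for item in recommended[:k] if item in relevant)
    return hits / k

def recall_at_k(recommended, relevant, k):
    hits = sum(1 for item in recommended[:k] if item in relevant)
    return hits / len(relevant) if relevant else 0.0

def f1_at_k(recommended, relevant, k):
    p = precision_at_k(recommended, relevant, k)
    r = recall_at_k(recommended, relevant, k)
    return 2 * p * r / (p + r) if p + r else 0.0

def average_precision_at_k(recommended, relevant, k):
    """Sum of precision@i over ranks i <= k that hold a relevant item,
    divided by min(|relevant|, k)."""
    hits, total = 0, 0.0
    for i, item in enumerate(recommended[:k], start=1):
        if item in relevant:
            hits += 1
            total += hits / i
    return total / min(len(relevant), k) if relevant else 0.0

scores = {"s1": 0.9, "s2": 0.2, "s3": 0.75, "s4": 0.6, "s5": 0.1}  # hypothetical
relevant = {"s1", "s4", "s5"}
recommended = top_k(scores, 3)
print(recommended)  # ['s1', 's3', 's4']
print(precision_at_k(recommended, relevant, 3), average_precision_at_k(recommended, relevant, 3))
```

Unlike plain precision, average precision rewards placing the relevant services near the top of the list, which is what matters when a user only looks at the first few recommendations.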

Conclusion

We hope that this post was useful and inspiring. If you like what you just read, think about giving us a ⭐ on our MAGE repository. If you have any questions, feel free to start a discussion on Discord or open an issue on GitHub.

See you soon with the next 🪄 spell.

References

Hamilton, William L., et al. “Inductive Representation Learning on Large Graphs.” Stanford Computer Science. Accessed 29 August 2022.

Veličković, Petar, et al. “Graph Attention Networks.” arXiv:1710.10903v3 [stat.ML], 4 February 2018. Accessed 29 August 2022.
