kmeans

The k-means algorithm clusters data by trying to separate samples into n groups of equal variance, minimizing a criterion known as the inertia, or within-cluster sum-of-squares.
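The criterion can be made concrete with a small sketch. The function name `inertia` and the toy data are illustrative assumptions, not part of the kmeans module: the value is the sum, over all samples, of the squared distance between a sample and the centroid of the cluster it is assigned to.

```python
def inertia(points, centroids, labels):
    """Within-cluster sum of squares (hypothetical helper, for illustration only)."""
    total = 0.0
    for point, label in zip(points, labels):
        centroid = centroids[label]
        # Squared Euclidean distance between the sample and its centroid.
        total += sum((p - c) ** 2 for p, c in zip(point, centroid))
    return total


# Toy data: two tight groups of two points each.
points = [(0.0, 0.0), (0.0, 2.0), (10.0, 0.0), (10.0, 2.0)]
centroids = [(0.0, 1.0), (10.0, 1.0)]
labels = [0, 0, 1, 1]

print(inertia(points, centroids, labels))  # 4.0
```

Assigning the same points to the wrong centroids gives a much larger value, which is why k-means prefers the first assignment.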

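+-------------------------+-------------------------+
| Trait                   | Value                   |
+-------------------------+-------------------------+
| Module type             | module                  |
| Implementation          | Python                  |
| Graph direction         | directed/undirected     |
| Edge weights            | weighted/unweighted     |
| Parallelism             | sequential              |
+-------------------------+-------------------------+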

Too slow?

If this algorithm implementation is too slow for your use case, open an issue on Memgraph's GitHub repository and request a rewrite to C++!

Procedures

get_clusters()

For each node, the get_clusters() procedure returns what cluster it belongs to.

Input:

  • n_clusters : int ➡ Number of clusters to be formed.
  • embedding_property : str (default: "embedding") ➡ Node property where embeddings are stored.
  • init : str (default: "k-means++") ➡ Initialization method, either k-means++ or random. If k-means++ is selected, initial cluster centroids are sampled per an empirical probability distribution of the points’ contribution to the overall inertia. This technique speeds up convergence and is theoretically proven to be O(log k)-optimal. If random, n_clusters observations (nodes) are randomly chosen for the initial centroids.
  • n_init : int (default: 10) ➡ Number of times the k-means algorithm will be run with different centroid seeds.
  • max_iter : int (default: 10) ➡ Maximum number of iterations of the k-means algorithm for a single run.
  • tol : float (default: 1e-4) ➡ Relative tolerance of the Frobenius norm of the difference of cluster centers across consecutive iterations. Used in determining convergence.
  • algorithm : str (default: "auto") ➡ Options are lloyd, elkan, auto, full. Description here.
  • random_state : int (default: 1998) ➡ Random seed for the algorithm.
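To illustrate what the k-means++ option of init does, here is a minimal sketch of textbook k-means++ seeding: each new centroid is drawn with probability proportional to a point's squared distance from the nearest centroid chosen so far. The helper names are hypothetical and not part of the kmeans module, and the library the module relies on may differ in details.

```python
import random


def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans_plus_plus_init(points, n_clusters, random_state=1998):
    """Pick n_clusters initial centroids from points (illustrative sketch)."""
    rng = random.Random(random_state)
    # The first centroid is chosen uniformly at random.
    centroids = [rng.choice(points)]
    while len(centroids) < n_clusters:
        # Weight of a point = squared distance to its nearest chosen centroid.
        weights = [min(squared_distance(p, c) for c in centroids) for p in points]
        if sum(weights) == 0:
            # All points coincide with a centroid; fall back to a uniform pick.
            centroids.append(rng.choice(points))
        else:
            centroids.append(rng.choices(points, weights=weights, k=1)[0])
    return centroids


sample = [(0.0, 0.0), (0.0, 1.0), (5.0, 5.0), (6.0, 5.0)]
print(kmeans_plus_plus_init(sample, 2, random_state=1))
```

Because an already chosen point has weight zero, distinct points are never picked twice, and far-away points are favored, which tends to spread the initial centroids across the data.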

Output:

  • node: mgp.Vertex ➡ Graph node.
  • cluster_id: mgp.Number ➡ Cluster ID of the above node.

Usage:

To get clusters, use the following query:

 CALL kmeans.get_clusters(2, "embedding", "k-means++", 10, 10, 0.0001, "auto", 1) YIELD node, cluster_id
 RETURN node.id as node_id, cluster_id
   ORDER BY node_id ASC;

set_clusters()

The set_clusters() procedure assigns each node to a cluster by writing the cluster ID to the node property named by cluster_property.

Input:

  • n_clusters : int ➡ Number of clusters to be formed.
  • embedding_property : str (default: "embedding") ➡ Node property where embeddings are stored.
  • cluster_property: str (default: "cluster_id") ➡ Node property where cluster_id will be stored.
  • init : str (default: "k-means++") ➡ Initialization method, either k-means++ or random. If k-means++ is selected, initial cluster centroids are sampled per an empirical probability distribution of the points’ contribution to the overall inertia. This technique speeds up convergence and is theoretically proven to be O(log k)-optimal. If random, n_clusters observations (nodes) are randomly chosen for the initial centroids.
  • n_init : int (default: 10) ➡ Number of times the k-means algorithm will be run with different centroid seeds.
  • max_iter : int (default: 10) ➡ Maximum number of iterations of the k-means algorithm for a single run.
  • tol : float (default: 1e-4) ➡ Relative tolerance of the Frobenius norm of the difference of cluster centers across consecutive iterations. Used in determining convergence.
  • algorithm : str (default: "auto") ➡ Options are lloyd, elkan, auto, full. Description here.
  • random_state : int (default: 1998) ➡ Random seed for the algorithm.
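How n_init, max_iter, and tol interact can be sketched with a plain Lloyd-style loop. This is a simplified illustration under stated assumptions, not the module's implementation: it uses random initialization only, and it compares the squared shift of the centroids against tol as an absolute threshold rather than a relative one. All function names are hypothetical.

```python
import random


def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest(point, centroids):
    return min(range(len(centroids)), key=lambda k: squared_distance(point, centroids[k]))


def lloyd(points, centroids, max_iter=10, tol=1e-4):
    """One k-means run: at most max_iter iterations, stopping early on a small shift."""
    for _ in range(max_iter):
        labels = [nearest(p, centroids) for p in points]
        updated = []
        for k, old in enumerate(centroids):
            members = [p for p, label in zip(points, labels) if label == k]
            # Keep the old centroid if a cluster ends up empty.
            updated.append(tuple(sum(axis) / len(members) for axis in zip(*members)) if members else old)
        shift = sum(squared_distance(o, n) for o, n in zip(centroids, updated))
        centroids = updated
        if shift <= tol:  # centroids barely moved: treat the run as converged
            break
    labels = [nearest(p, centroids) for p in points]
    inertia = sum(squared_distance(p, centroids[label]) for p, label in zip(points, labels))
    return labels, centroids, inertia


def best_of_n_init(points, n_clusters, n_init=10, max_iter=10, tol=1e-4, random_state=1998):
    """Repeat the run n_init times from different seeds and keep the lowest inertia."""
    rng = random.Random(random_state)
    best = None
    for _ in range(n_init):
        start = rng.sample(points, n_clusters)
        result = lloyd(points, start, max_iter, tol)
        if best is None or result[2] < best[2]:
            best = result
    return best


data = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
labels, centroids, inertia = best_of_n_init(data, n_clusters=2)
print(labels, inertia)
```

A larger n_init lowers the chance of keeping a poor local optimum, while max_iter and tol bound how long each individual run takes.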

Output:

  • node: mgp.Vertex ➡ Graph node.
  • cluster_id: mgp.Number ➡ Cluster ID of the above node.

Usage:

To set clusters, use the following query:

 CALL kmeans.set_clusters(2, "embedding", "cluster_id", "k-means++", 10, 10, 0.0001, "auto", 1) YIELD node, cluster_id
 RETURN node.id as node_id, cluster_id
   ORDER BY node_id ASC;

Example

Database state

The database contains the following data, created with these Cypher queries:

CREATE (:Node {id:0, embedding: [0.90678340196609497, 0.74690568447113037, -0.65984714031219482]});
CREATE (:Node {id:1, embedding: [0.90674564654509497, 0.74690568444653456, -0.64324344031219482]});
CREATE (:Node {id:2, embedding: [0.86545435296609497, 0.75434568447113037, -0.65984234323419482]});
CREATE (:Node {id:3, embedding: [0.90432432496609497, 0.65454328447113037, -0.54654234031219482]});
CREATE (:Node {id:4, embedding: [0.84654334234609497, 0.69994534547113037, -0.65383452341219482]});
CREATE (:Node {id:5, embedding: [0.26445634534539497, 0.12543654765534037, 0.463475464352342382]});
CREATE (:Node {id:6, embedding: [0.37654756534539497, 0.13455654447113037, 0.465437534631219482]});
CREATE (:Node {id:7, embedding: [0.32565654645609497, 0.24532645435433037, 0.487695876576219482]});
CREATE (:Node {id:8, embedding: [0.23565676686539497, 0.23453462345345337, 0.463457345634523482]});
MATCH (a:Node {id: 0}) MATCH (b:Node {id: 1}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 0}) MATCH (b:Node {id: 4}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 1}) MATCH (b:Node {id: 2}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 2}) MATCH (b:Node {id: 0}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 2}) MATCH (b:Node {id: 4}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 3}) MATCH (b:Node {id: 2}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 4}) MATCH (b:Node {id: 1}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 5}) MATCH (b:Node {id: 7}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 6}) MATCH (b:Node {id: 5}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 6}) MATCH (b:Node {id: 7}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 7}) MATCH (b:Node {id: 8}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 8}) MATCH (b:Node {id: 3}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 8}) MATCH (b:Node {id: 5}) MERGE (a)-[:RELATION]->(b);
MATCH (a:Node {id: 8}) MATCH (b:Node {id: 6}) MERGE (a)-[:RELATION]->(b);

Get clusters

To get clusters, use the following query:

CALL kmeans.get_clusters(2, "embedding", "k-means++", 10, 10, 0.0001, "auto", 1) YIELD node, cluster_id
RETURN node.id as node_id, cluster_id
ORDER BY node_id ASC;

Results:

+-------------------------+-------------------------+
| node_id                 | cluster_id              |
+-------------------------+-------------------------+
| 0                       | 1                       |
| 1                       | 1                       |
| 2                       | 1                       |
| 3                       | 1                       |
| 4                       | 1                       |
| 5                       | 0                       |
| 6                       | 0                       |
| 7                       | 0                       |
| 8                       | 0                       |
+-------------------------+-------------------------+
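As a sanity check on the example data, the sketch below confirms that the embeddings of nodes 0 to 4 and nodes 5 to 8 form two well-separated groups: every distance within a group is smaller than every distance across groups. The embeddings are rounded to four decimals for readability, and the script is an illustration that runs outside Memgraph, not part of the module.

```python
from itertools import combinations
from math import dist

# Embeddings from the CREATE queries above, rounded to four decimals.
embeddings = {
    0: (0.9068, 0.7469, -0.6598),
    1: (0.9067, 0.7469, -0.6432),
    2: (0.8655, 0.7543, -0.6598),
    3: (0.9043, 0.6545, -0.5465),
    4: (0.8465, 0.6999, -0.6538),
    5: (0.2645, 0.1254, 0.4635),
    6: (0.3765, 0.1346, 0.4654),
    7: (0.3257, 0.2453, 0.4877),
    8: (0.2357, 0.2345, 0.4635),
}

group_a = [0, 1, 2, 3, 4]
group_b = [5, 6, 7, 8]

# Largest distance between two nodes of the same group.
max_within = max(
    dist(embeddings[i], embeddings[j])
    for group in (group_a, group_b)
    for i, j in combinations(group, 2)
)
# Smallest distance between a node of one group and a node of the other.
min_across = min(dist(embeddings[i], embeddings[j]) for i in group_a for j in group_b)

print(max_within < min_across)  # True
```

With two clusters requested, k-means is therefore expected to split the nodes along exactly this boundary; which group receives cluster ID 0 or 1 is arbitrary.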