Waleed Esmail

Numpy implementation of simple clustering algorithm

I would like to implement a simple clustering algorithm in Python. First I will describe the problem:

I have some points, each represented by an id, and a pair probability for each pair of points, i.e. prob(id1, id2) = some_value. These are arranged in a NumPy array of shape [N, 3], where N is the number of all possible point pairs. To make this clearer, here is an example array:

import numpy as np

a = np.array([[1, 2, 0.9],
              [2, 3, 0.63],
              [3, 4, 0.98],
              [4, 5, 0.1],
              [5, 6, 0.98],
              [6, 7, 1]])

where the first two entries are the point ids and the third entry is the probability that they belong to each other.

The clustering problem is to connect points whose pair probability passes the cut (cut = 0.5), i.e. points 1, 2, 3, 4 belong to the same cluster and 5, 6, 7 belong to another cluster. My current solution builds a list of lists of point ids, i.e. l = [[1,2,3,4],[5,6,7]], by looping twice over the unique point ids and the array a. Is there a smarter and faster way to do this?

Upvotes: 1


Answers (1)

Mad Physicist

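The problem you describe is a graph problem: the point ids are nodes, each pair whose probability passes the threshold is an edge, and the clusters you want are the connected components of that graph. Many common graph algorithms, including connected components, are implemented in the networkx package.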
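To make the connected-components reduction concrete without any graph library, here is a minimal sketch using a union-find (disjoint-set) structure. This is not the answer's method; cluster_pairs is a hypothetical helper name, and the >= comparison is an assumption chosen to mirror the threshold test used below. Each pair is processed once, which avoids the double loop from the question.

```python
import numpy as np


def cluster_pairs(pairs, cut=0.5):
    """Hypothetical helper: group point ids so that two points share a cluster
    when they are linked by a chain of pairs with probability >= cut."""
    parent = {}

    def find(x):
        # Register x on first sight, then walk to its root with path halving.
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for i, j, p in pairs:
        i, j = int(i), int(j)
        # Both ids are registered even when the pair fails the cut,
        # so points without any strong link end up as singleton clusters.
        ri, rj = find(i), find(j)
        if p >= cut and ri != rj:
            parent[rj] = ri

    # Collect the members of each root into one cluster.
    clusters = {}
    for x in parent:
        clusters.setdefault(find(x), []).append(x)
    return sorted(sorted(c) for c in clusters.values())


a = np.array([[1, 2, 0.9],
              [2, 3, 0.63],
              [3, 4, 0.98],
              [4, 5, 0.1],
              [5, 6, 0.98],
              [6, 7, 1]])
print(cluster_pairs(a))  # [[1, 2, 3, 4], [5, 6, 7]]
```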

import numpy as np
import networkx as nx
threshold = 0.5

If your threshold is written in stone, you can pre-apply it, and build your Graph from the remaining data:

G = nx.Graph()
G.add_weighted_edges_from(a[a[:, -1] >= threshold])
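A self-contained version of the pre-filtered approach that returns the same list-of-lists shape as the question's l might look like the sketch below. Casting the ids to int is an addition of this sketch (not in the answer) so the nodes come out as 1, 2, ... instead of 1.0, 2.0, ....

```python
import numpy as np
import networkx as nx

a = np.array([[1, 2, 0.9],
              [2, 3, 0.63],
              [3, 4, 0.98],
              [4, 5, 0.1],
              [5, 6, 0.98],
              [6, 7, 1]])
threshold = 0.5

# Keep only the rows whose probability passes the threshold.
kept = a[a[:, -1] >= threshold]

G = nx.Graph()
# Cast the id columns to int so node labels are integers, not floats.
G.add_weighted_edges_from((int(u), int(v), float(w)) for u, v, w in kept)

# Each connected component is a set of ids; sort them into lists.
clusters = sorted(sorted(c) for c in nx.connected_components(G))
print(clusters)  # [[1, 2, 3, 4], [5, 6, 7]]
```

One consequence of filtering first: a point whose pairs all fall below the threshold never becomes a node, so it will not show up as a singleton cluster.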

If you want to experiment with the threshold after building the graph, you can build it from the whole array and then remove the edges that fall below the threshold. This will be slower than the pre-filtered version:

G = nx.Graph()
G.add_weighted_edges_from(a)
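# Build the list first: removing edges while iterating over G.edges would fail with a RuntimeError
G.remove_edges_from([(u, v) for u, v, w in G.edges(data='weight') if w < threshold])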

or alternatively

G.remove_edges_from(a[a[:, -1] < threshold])

or

G.remove_edges_from(r for r in a if r[-1] < threshold)
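The main reason to build the full graph is to try several thresholds. A minimal sketch of that, assuming you keep the full graph intact and prune a copy for each threshold (clusters_at is a hypothetical helper name, and the int casting is again an addition of this sketch):

```python
import numpy as np
import networkx as nx

a = np.array([[1, 2, 0.9],
              [2, 3, 0.63],
              [3, 4, 0.98],
              [4, 5, 0.1],
              [5, 6, 0.98],
              [6, 7, 1]])

# Full graph with every pair, built once.
G_full = nx.Graph()
G_full.add_weighted_edges_from((int(u), int(v), float(w)) for u, v, w in a)


def clusters_at(G, threshold):
    """Hypothetical helper: prune a copy of G below threshold, return clusters."""
    H = G.copy()  # leave the full graph untouched for the next threshold
    # Materialize the edge list before removing anything from H.
    weak = [(u, v) for u, v, w in H.edges(data="weight") if w < threshold]
    H.remove_edges_from(weak)
    return sorted(sorted(c) for c in nx.connected_components(H))


for t in (0.05, 0.5, 0.95):
    print(t, clusters_at(G_full, t))
```

With the example array this gives a single cluster at 0.05, the two expected clusters at 0.5, and [[1], [2], [3, 4], [5, 6, 7]] at 0.95. Points 1 and 2 survive as singletons there because removing edges leaves their nodes in the graph, unlike the pre-filtered construction.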

Whether you construct a reduced graph or reduce it by removing edges, you can get the number and contents of the clusters from the connected components (the ids print as floats because a is a float array):

>>> nx.number_connected_components(G)
2
>>> for c in nx.connected_components(G):
...     print(c)
{1.0, 2.0, 3.0, 4.0}
{5.0, 6.0, 7.0}
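Since the question asks about NumPy specifically, another option that stays in array land is SciPy's sparse-graph routines. This is a swapped-in technique, not part of the answer above, and it assumes SciPy is installed. The sketch maps the ids onto 0..n-1, builds a sparse adjacency matrix from the pairs that pass the threshold, and lets scipy.sparse.csgraph.connected_components label each point.

```python
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

a = np.array([[1, 2, 0.9],
              [2, 3, 0.63],
              [3, 4, 0.98],
              [4, 5, 0.1],
              [5, 6, 0.98],
              [6, 7, 1]])
threshold = 0.5

ids = a[:, :2].astype(int)
keep = a[:, 2] >= threshold

# Map arbitrary ids onto matrix indices 0..n-1.
unique_ids, idx = np.unique(ids.ravel(), return_inverse=True)
idx = idx.reshape(ids.shape)
n = unique_ids.size

# One nonzero entry per pair that passes the threshold.
adj = coo_matrix(
    (np.ones(keep.sum()), (idx[keep, 0], idx[keep, 1])),
    shape=(n, n),
).tocsr()

# directed=False lets a link be followed in either direction.
n_clusters, labels = connected_components(adj, directed=False)

# Gather the original ids for each component label.
clusters = sorted(unique_ids[labels == k].tolist() for k in range(n_clusters))
print(n_clusters, clusters)  # 2 [[1, 2, 3, 4], [5, 6, 7]]
```

Because every id that appears in a gets a matrix index, points with no passing pair come out as singleton components here as well.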

Upvotes: 2
