Camilo Martínez M.

Reputation: 1620

Fastest way of creating a dictionary which acts as a group-by style look-up in a 2D numpy array?

Suppose I have a 2D numpy array whose values correspond to a label or class. For example, if A = [[0, 0, 1, 1], [1, 1, 1, 0]], then positions (0, 0), (0, 1), (1, 3) correspond to class '0' and (0, 2), (0, 3), (1, 0), etc. correspond to class '1'. This is a very simple example; in general, I would be dealing with matrices with many more items.

What I want to do is essentially build a dictionary where a key corresponds to each class and its corresponding value is a list of tuples, where each tuple corresponds to a position of the input matrix whose value is the key. In other words, group the input matrix by its values and obtain a list of positions where each unique value occurs.
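To make the target structure concrete, here is a minimal, deliberately slow sketch of the grouping described above. The helper name group_positions_naive is hypothetical and only serves as a readable reference for checking faster versions; it is not the code being timed in this question.

```python
import numpy as np
from collections import defaultdict

A = np.array([[0, 0, 1, 1],
              [1, 1, 1, 0]])

def group_positions_naive(A):
    """Reference version: visit every cell and append its position to its class."""
    groups = defaultdict(list)
    # np.ndenumerate yields ((row, col), value) pairs in row-major order.
    for pos, value in np.ndenumerate(A):
        groups[int(value)].append(pos)
    return dict(groups)

expected = group_positions_naive(A)
print(expected)
# {0: [(0, 0), (0, 1), (1, 3)], 1: [(0, 2), (0, 3), (1, 0), (1, 1), (1, 2)]}
```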

For now, I have the following code:

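import numpy as np

# One empty list per class label (assumes non-negative integer labels).
S = {i: [] for i in range(A.max() + 1)}
for i in range(A.shape[0]):
    index = np.arange(A[i].shape[0])
    sort_idx = np.argsort(A[i])   # column indices of row i, ordered by label
    cnt = np.bincount(A[i])       # number of occurrences of each label in row i
    # Split the sorted column indices into one chunk per label.
    result = np.split(index[sort_idx], np.cumsum(cnt[:-1]))
    for j, k in enumerate(result):
        S[j] += [(i, z) for z in k]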

Here, A is my input matrix. This takes around 0.4 ms on average for a 500x500 matrix. Still, I feel it could be improved further, perhaps by making better use of vectorization.

Could someone guide me on how it could be made simpler and/or faster? Any help is appreciated. Thanks!

Upvotes: 3

Views: 198

Answers (1)

Susmit Agrawal

Reputation: 3764

You can do this much more simply using np.argwhere and np.unique:

S = {}
for key in np.unique(A):
    S[key] = np.argwhere(A==key)

Note that each value in S is a 2D numpy array of shape (n, 2), with one (row, column) pair per row, rather than a list of tuples.
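The loop above scans the whole array once per class, which may matter when there are many distinct labels. One possible alternative, sketched here and not benchmarked, is to sort the flattened array once and split the sorted positions by class counts. The function name group_positions is hypothetical, and the last line shows one way to get lists of tuples if the question's exact format is needed.

```python
import numpy as np

def group_positions(A):
    """Group (row, col) positions by value using a single sort of the flat array."""
    flat = A.ravel()
    # A stable sort keeps positions of equal values in row-major order.
    order = np.argsort(flat, kind="stable")
    keys, counts = np.unique(flat, return_counts=True)
    # Convert the sorted flat indices back into (row, col) coordinates.
    rows, cols = np.unravel_index(order, A.shape)
    coords = np.column_stack((rows, cols))
    # Cut the coordinate array at the boundary between consecutive classes.
    chunks = np.split(coords, np.cumsum(counts)[:-1])
    return {int(k): chunk for k, chunk in zip(keys, chunks)}

A = np.array([[0, 0, 1, 1],
              [1, 1, 1, 0]])
S = group_positions(A)

# Optional: convert each (n, 2) array into a list of tuples.
S_tuples = {k: list(map(tuple, v.tolist())) for k, v in S.items()}
```

Converting to tuples adds Python-level work, so it is probably best skipped when the (n, 2) arrays can be used directly.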

Upvotes: 5
