Mehran

What is the fastest way of selecting a subset of a JAX matrix?

Let's say I have a 2D matrix and I want to plot its values in a histogram. For that, I need to do something like:

list_1d = matrix_2d.reshape((-1,)).tolist()

And then use the list to plot the histogram. So far, so good. The catch is that there are items in the original matrix that I want to exclude. For simplicity, let's say I have a list like this:

exclude = [(2, 5), (3, 4), (6, 1)]

So list_1d should contain every item in the matrix except those pointed to by exclude (each item of exclude is a (row, column) index pair).

And BTW, matrix_2d is a JAX array, which means its contents live on the GPU.
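To make the (row, column) convention concrete: in a row-major (C-order) array with n_cols columns, the pair (r, c) sits at position r * n_cols + c of the flattened array. The following is a minimal sketch, not taken from the thread, that converts exclude to flat positions and drops them with jnp.delete. It assumes a 10x10 matrix filled with arange purely for illustration, so each value equals its own flat index.

```python
import jax.numpy as jnp

# Illustrative 10x10 matrix (assumption): each value equals its flat row-major position.
matrix_2d = jnp.arange(100, dtype=jnp.float32).reshape(10, 10)
exclude = [(2, 5), (3, 4), (6, 1)]

# Convert (row, col) pairs to flat indices: r * n_cols + c.
n_cols = matrix_2d.shape[1]
flat_exclude = [r * n_cols + c for r, c in exclude]  # [25, 34, 61]

# With axis=None (the default), jnp.delete flattens the input first and
# returns a new 1D array without the given positions.
kept = jnp.delete(matrix_2d, jnp.array(flat_exclude))
list_1d = kept.tolist()
print(len(list_1d))  # 97
```

Like boolean-mask indexing, this produces an output whose length depends on the indices, so it is meant for eager (non-jit) use.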

Answers (1)

jakevdp

One way to do this is to create a boolean mask array and use it to select the desired subset of the array. Indexing with the mask returns a 1D copy of the selected data:

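import jax.numpy as jnp
from jax import random

matrix_2d = random.uniform(random.PRNGKey(0), (10, 10))
exclude = [(2, 5), (3, 4), (6, 1)]

# Transpose the (row, col) pairs into a (rows, cols) tuple usable as an index.
ind = tuple(jnp.array(exclude).T)
# Start from an all-True mask and switch off the excluded positions.
mask = jnp.ones_like(matrix_2d, dtype=bool).at[ind].set(False)

# Boolean-mask indexing returns a flattened copy of the kept entries.
list_1d = matrix_2d[mask].tolist()
len(list_1d)
# 97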
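Boolean-mask indexing produces an array whose length depends on the mask values, so it works eagerly but cannot be traced under jax.jit. If the end goal is only the histogram, one alternative (a sketch, not part of the answer above) is to keep the full array on the device and pass the mask as weights to jnp.histogram, so excluded entries contribute zero counts. The bin count of 10 and the range (0.0, 1.0) are illustrative assumptions chosen to match random.uniform, and masked_histogram is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp
from jax import random

matrix_2d = random.uniform(random.PRNGKey(0), (10, 10))
exclude = [(2, 5), (3, 4), (6, 1)]

ind = tuple(jnp.array(exclude).T)
mask = jnp.ones_like(matrix_2d, dtype=bool).at[ind].set(False)

@jax.jit
def masked_histogram(values, mask):
    # Hypothetical helper: a fixed bin count and range keep every shape static under jit.
    # Excluded entries get weight 0, so they do not add to any bin.
    return jnp.histogram(
        values, bins=10, range=(0.0, 1.0), weights=mask.astype(values.dtype)
    )

counts, edges = masked_histogram(matrix_2d, mask)
print(float(counts.sum()))  # 97.0
```

The counts and edges stay on the device until you convert them, which can then be a 10-element transfer instead of the whole matrix.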
