Let's say I have a 2D matrix and I want to plot its values in a histogram. For that, I need to do something like:
list_1d = matrix_2d.reshape((-1,)).tolist()
And then use the list to plot the histogram. So far so good; the catch is that there are items in the original matrix that I want to exclude. For simplicity, let's say I have a list like this:
exclude = [(2, 5), (3, 4), (6, 1)]
So list_1d should contain all the items in the matrix except the ones pointed to by exclude (the items of exclude are (row, column) index pairs).
And by the way, matrix_2d is a JAX array, which means its contents live on the GPU.
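To make the indexing concrete: reshape((-1,)) flattens in row-major order, so the element at (r, c) of a matrix with n_cols columns ends up at position r * n_cols + c of the flattened array. The minimal sketch below checks that mapping and then uses jnp.delete to drop those flat positions. This is an alternative not taken from the original post, and the random 10x10 matrix is only assumed example data.

```python
import jax.numpy as jnp
from jax import random

# Assumed example data: a random 10x10 matrix and the excluded (row, col) pairs.
matrix_2d = random.uniform(random.PRNGKey(0), (10, 10))
exclude = [(2, 5), (3, 4), (6, 1)]

n_rows, n_cols = matrix_2d.shape
flat = matrix_2d.reshape((-1,))

# Row-major flattening: element (r, c) lands at position r * n_cols + c.
for r, c in exclude:
    assert flat[r * n_cols + c] == matrix_2d[r, c]

# jnp.delete needs an array (not a Python list) of positions to remove.
flat_exclude = jnp.array([r * n_cols + c for r, c in exclude])

# Returns a new 1D array without those positions, preserving the order.
list_1d = jnp.delete(flat, flat_exclude).tolist()
print(len(list_1d))  # 97
```

Because jnp.delete keeps the remaining elements in their original order, the result matches what a boolean mask over the same positions would select.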
One way to do this is to create a mask array that you use to select the desired subset of the array. The mask indexing operation returns a 1D copy of the selected data:
import jax.numpy as jnp
from jax import random
matrix_2d = random.uniform(random.PRNGKey(0), (10, 10))
exclude = [(2, 5), (3, 4), (6, 1)]
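# Transpose the (row, col) pairs into a (rows, cols) tuple usable as an index
ind = tuple(jnp.array(exclude).T)
# Boolean mask: True everywhere except at the excluded positions
mask = jnp.ones_like(matrix_2d, dtype=bool).at[ind].set(False)
# Boolean-mask indexing returns a flattened 1D copy of the selected values
list_1d = matrix_2d[mask].tolist()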
len(list_1d)
# 97
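Two caveats about the masking approach: the shape of matrix_2d[mask] depends on the mask's values, so that indexing works eagerly but not inside jax.jit, and .tolist() copies the data back from the GPU to the host. If the end goal is only the histogram, one option (a sketch, not part of the original answer) is to keep the full array and give excluded entries zero weight via the weights argument of jnp.histogram. The helper name masked_histogram and the choice of bins=10 with range=(0.0, 1.0) are assumptions that fit the uniform example data.

```python
import jax
import jax.numpy as jnp
from jax import random

matrix_2d = random.uniform(random.PRNGKey(0), (10, 10))
exclude = [(2, 5), (3, 4), (6, 1)]

# Same mask as in the answer: True everywhere except the excluded positions.
ind = tuple(jnp.array(exclude).T)
mask = jnp.ones_like(matrix_2d, dtype=bool).at[ind].set(False)

@jax.jit
def masked_histogram(values, mask):
    # Hypothetical helper. Excluded entries get weight 0, so they add nothing
    # to any bin. All shapes stay static, which jax.jit requires.
    weights = mask.ravel().astype(values.dtype)
    counts, edges = jnp.histogram(values.ravel(), bins=10,
                                  range=(0.0, 1.0), weights=weights)
    return counts, edges

counts, edges = masked_histogram(matrix_2d, mask)
print(counts.sum())  # 97.0
```

The counts and bin edges stay on the device; only these 21 numbers need to be moved to the host for plotting (for example with a bar plot over edges), instead of all 97 values.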