Cypher

Reputation: 2591

Numpy - Filtering The Array Based on Indices and Respecting Axis

I have an integer ids array and a float distances_per_batch array:

import numpy as np

BATCH_SIZE = 500
ARRAY_SIZE = 10000

ids = np.arange(ARRAY_SIZE)  # shape: (ARRAY_SIZE,)
distances_per_batch = np.random.rand(BATCH_SIZE, ARRAY_SIZE)  # shape: (BATCH_SIZE, ARRAY_SIZE)

I'm trying to get the ids whose distance is greater than 0.9:

# shape: (BATCH_SIZE, ARRAY_SIZE). Not sure this is the right approach: it is slow
# for larger BATCH_SIZE and ARRAY_SIZE and seems to create a new array.
ids_expanded = np.repeat(np.expand_dims(ids, axis=0), BATCH_SIZE, axis=0)
selected_ids = ids_expanded[distances_per_batch > 0.9]

I expect selected_ids to be 2-dimensional with shape (500, ?), holding, for each of the 500 entries in the batch, the ids whose distance is greater than 0.9. Instead, the result is flattened to a 1-dimensional array, so I can't tell which selected id belongs to which of the 500 entries.

How can I get the desired result quickly and properly, using NumPy's vectorized methods rather than looping over every record one by one? I'm also not sure that expanding dimensions and repeating the array is the right approach, since it is slow for larger BATCH_SIZE and ARRAY_SIZE and seems to create a new array.

Upvotes: 0

Views: 74

Answers (1)

splash58

Reputation: 26153

np.where(distances_per_batch > 0.9) 

returns two separate arrays: the row indices and the column indices of the matching elements. To combine them into (row, column) pairs:

np.transpose(np.where(distances_per_batch > 0.9))

With random data it returns something like:

array([[   0,    0],
       [   0,   31],
       [   0,   33],
       ..., 
       [ 499, 9988],
       [ 499, 9993],
       [ 499, 9995]], dtype=int32)
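
Since each row can match a different number of ids, the result cannot be a regular (500, ?) array. One possible way to get a per-row grouping from the indices above is to split the selected ids at the row boundaries. This is only a sketch: the small sizes, the fixed seed, the offset ids, and the names mask, counts and selected_per_row are assumptions made for illustration, not part of the original question or answer.

```python
import numpy as np

# Small illustrative sizes and a fixed seed (assumptions, not the question's values).
BATCH_SIZE = 4
ARRAY_SIZE = 10
rng = np.random.default_rng(0)

# Offset ids (hypothetical) so that an id differs from its column position.
ids = np.arange(ARRAY_SIZE) + 100
distances_per_batch = rng.random((BATCH_SIZE, ARRAY_SIZE))

mask = distances_per_batch > 0.9

# np.nonzero returns indices in row-major order, so `rows` is already sorted.
rows, cols = np.nonzero(mask)

# Number of matches in each row, including rows with no match at all.
counts = mask.sum(axis=1)

# Split the flat list of selected ids at the row boundaries.
# The result is a list of BATCH_SIZE arrays of varying length.
selected_per_row = np.split(ids[cols], np.cumsum(counts)[:-1])

for row, selected in enumerate(selected_per_row):
    print(row, selected)
```

Indexing ids with cols avoids building the repeated ids_expanded array. When ids is simply np.arange(ARRAY_SIZE), as in the question, ids[cols] is the same as cols.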

Upvotes: 1
