Reputation: 221
I am trying to find the most efficient way to do slicing on a 3D NumPy array. This is a subset of the data, just for test purposes:
import numpy as np
in_arr = np.array([[[0, 1, 2, 5], [2, 3, 2, 6], [0, 1, 3, 2]],
                   [[1, 2, 3, 4], [3, 1, 0, 5], [2, 4, 0, 1]]])
indx = [[3, 1, 2], [2, 0, 1]]
I need to get the value at the position given by indx. For example, indx[0][0] is 3, so I am looking for the element at (0-based) position 3 of in_arr[0][0], which in this case is 5.
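In other words, the result should satisfy result[x][y] = in_arr[x][y][indx[x][y]], with 0-based positions. A minimal check of the example above, restating the test arrays from the question so the snippet runs on its own:

```python
import numpy as np

in_arr = np.array([[[0, 1, 2, 5], [2, 3, 2, 6], [0, 1, 3, 2]],
                   [[1, 2, 3, 4], [3, 1, 0, 5], [2, 4, 0, 1]]])
indx = [[3, 1, 2], [2, 0, 1]]

# indx[0][0] is 3, so we want position 3 (0-based) of the row in_arr[0][0]
row = in_arr[0][0]          # array([0, 1, 2, 5])
picked = row[indx[0][0]]    # same as row[3]
print(int(picked))          # 5
```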
I have the following code that does what I need, but its time complexity is O(n^2), which I am not happy about.
list_in = []
for x in range(len(indx)):
    arr2 = []
    for y in range(len(indx[x])):
        # pick the element at position indx[x][y] of the row in_arr[x][y]
        arr2.append(in_arr[x][y][indx[x][y]])
        # print(in_arr[x][y][indx[x][y]])
    list_in.append(arr2)
print(list_in)
I am looking for a very fast and efficient way to do the same task for a large dataset.
Upvotes: 0
Views: 273
Reputation: 86320
You can do this efficiently using broadcasted arrays of indices; for example:
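i1 = np.arange(2)[:, np.newaxis]  # shape (2, 1): index along the first axis
i2 = np.arange(3)[np.newaxis, :]  # shape (1, 3): index along the second axis
i3 = np.array(indx)               # shape (2, 3): index along the third axis
in_arr[i1, i2, i3]
# array([[5, 3, 3],
#        [3, 3, 4]])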
What NumPy does here is effectively match up the entries of the three index arrays and extract the corresponding entries from in_arr. The reason for the [:, np.newaxis] and [np.newaxis, :] terms is that they reshape i1 to shape (2, 1) and i2 to shape (1, 3), so that, together with i3 of shape (2, 3), all three index arrays are compatible under NumPy's broadcasting rules and broadcast to the common shape (2, 3).
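A sketch that avoids hard-coding the 2 and 3 above, assuming in_arr is 3-D and indx has shape in_arr.shape[:2]. Here np.ogrid builds the two open (broadcastable) index grids instead of np.arange plus np.newaxis. np.take_along_axis (available since NumPy 1.15) is a different built-in route to the same gather and is included only as a cross-check:

```python
import numpy as np

in_arr = np.array([[[0, 1, 2, 5], [2, 3, 2, 6], [0, 1, 3, 2]],
                   [[1, 2, 3, 4], [3, 1, 0, 5], [2, 4, 0, 1]]])
i3 = np.array([[3, 1, 2], [2, 0, 1]])   # indx from the question

n0, n1 = i3.shape                       # assumed to equal in_arr.shape[:2]
i1, i2 = np.ogrid[:n0, :n1]             # shapes (n0, 1) and (1, n1)
print(i1.shape, i2.shape, i3.shape)     # (2, 1) (1, 3) (2, 3)

# All three index arrays broadcast to (n0, n1); element [x, y] of the
# result is in_arr[i1[x, 0], i2[0, y], i3[x, y]] = in_arr[x, y, indx[x][y]].
result = in_arr[i1, i2, i3]
print(result.tolist())                  # [[5, 3, 3], [3, 3, 4]]

# Alternative: gather along the last axis, then drop the length-1 axis.
alt = np.take_along_axis(in_arr, i3[:, :, np.newaxis], axis=2)[:, :, 0]
assert np.array_equal(result, alt)
```

Both versions do the whole gather in one vectorized indexing call, so no Python-level loop runs per element, which is what matters for a large dataset.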
Upvotes: 1