Reputation: 152
One of the great features of NumPy arrays is that you can perform multidimensional slicing. I am wondering exactly how it is implemented. Let me lay out what I am thinking so far, then hopefully someone can fill in the gaps, answer some questions I have, and (probably) tell me why I'm wrong.
import numpy as np
arr = np.array([ [1, 2, 3], [4, 5, 6] ])
# retrieve the rightmost column of values for all rows
print(arr[:, 2])
# indexing a normal multidimensional list
not_an_arr = [ [1, 2, 3], [4, 5, 6] ]
print(not_an_arr[:, 2]) # TypeError: list indices must be integers or slices, not tuple
At first, [:, 2] seemed like a violation of Python syntax to me. If I tried to index a normal multidimensional list in Python I would get an error. Of course, upon actually reading the error message, I realize that the issue isn't with the syntax as I originally thought, but with the type of object passed in. So the conclusion I've come to is that [:, 2] implicitly creates a tuple, so that what's really happening in [:, 2] is [(:, 2)]. Is that what's happening?
I next tried to read the source code for the numpy.ndarray class, which is linked to by the ndarray documentation, but that's all in C, which I'm not proficient in, so I can't make much sense of it.
I then noticed that there was documentation for ndarray.__getitem__. I was hoping this would lead me to the implementation of __getitem__ for the class, since my understanding is that implementing __getitem__ is where the behavior for indexing an object should be defined. My hope was that I would be able to see that they unpack the tuple and then use the slice objects or integers included in it to do the indexing on the underlying data structure, however that may need to be done.
So... what really goes on behind the scenes to make multidimensional slicing work on numpy arrays?
TLDR: How is multidimensional array slicing implemented for numpy arrays?
Upvotes: 3
Views: 1625
Reputation: 231738
We can verify your first level inferences with a simple class:
In [137]: class Foo():
     ...:     def __getitem__(self, arg):
     ...:         print(arg)
     ...:         return None
     ...:
In [138]: f=Foo()
In [139]: f[1]
1
In [140]: f[::3]
slice(None, None, 3)
In [141]: f[,]
File "<ipython-input-141-d115e3c638fb>", line 1
f[,]
^
SyntaxError: invalid syntax
In [142]: f[:,]
(slice(None, None, None),)
In [143]: f[:,:3,[1,2,3]]
(slice(None, None, None), slice(None, 3, None), [1, 2, 3])
numpy uses code like this in np.lib.index_tricks.py to implement "functions" like np.r_ and np.s_. They are actually class instances that use an index syntax.
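As a small illustration of that point, np.s_ can be used to capture the object that an index expression produces. This is only a sketch using the question's array; the variable names idx and col are my own.

```python
import numpy as np

arr = np.array([[1, 2, 3], [4, 5, 6]])

# np.s_ hands back whatever its __getitem__ receives
idx = np.s_[:, 2]
print(idx)   # (slice(None, None, None), 2)

# Indexing with the stored tuple gives the same result as arr[:, 2]
col = arr[idx]
print(col)   # [3 6]
```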
It's worth noting that it's the comma, more so than the (), that creates a tuple:
In [145]: 1,
Out[145]: (1,)
In [146]: 1,2
Out[146]: (1, 2)
In [147]: () # the one exception: an empty tuple needs no comma
Out[147]: ()
That explains the syntax. But the implementation details are left up to the object class. list (and other sequences like string) can work with integers and slice objects, but give an error when given a tuple.
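A quick check of that behaviour on the list from the question (the variable names here are just for illustration):

```python
not_an_arr = [[1, 2, 3], [4, 5, 6]]

first_row = not_an_arr[0]              # integer index works
all_rows = not_an_arr[slice(None)]     # an explicit slice object works, same as [:]

try:
    not_an_arr[:, 2]                   # the comma makes this a tuple index
except TypeError as exc:
    message = str(exc)

print(message)   # list indices must be integers or slices, not tuple
```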
numpy is happy with the tuple. In fact passing a tuple via __getitem__ was added years ago to base Python because numpy needed it. No base classes use it (that I know of); but user classes can accept a tuple, as my example shows.
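To make the "user classes can accept a tuple" point concrete, here is a minimal sketch of a class that unpacks the tuple itself. Grid2D is a hypothetical name, it only handles the two-index case, and it stores a list of rows, which is not how numpy stores its data.

```python
class Grid2D:
    """Hypothetical wrapper around a list of rows (not NumPy's storage model)."""

    def __init__(self, rows):
        self.rows = [list(r) for r in rows]

    def __getitem__(self, key):
        # A plain integer or slice behaves like indexing the list of rows
        if not isinstance(key, tuple):
            return self.rows[key]
        row_key, col_key = key  # only the two-index form is handled
        if isinstance(row_key, slice):
            # Apply the column index to every selected row
            return [row[col_key] for row in self.rows[row_key]]
        return self.rows[row_key][col_key]


g = Grid2D([[1, 2, 3], [4, 5, 6]])
print(g[:, 2])    # [3, 6]
print(g[1, :2])   # [4, 5]
print(g[0, 1])    # 2
```

The real ndarray does far more (ellipsis, newaxis, integer and boolean array indices), but the entry point is the same: one __getitem__ call that receives a tuple.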
As for the numpy details, that requires some knowledge of numpy array storage, including the role of the shape, strides and data-buffer. I'm not sure if I want to get into those now.
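Without going into the C code, the attributes themselves can be inspected from Python. The sketch below assumes a C-contiguous int64 array (I set the dtype explicitly so the byte counts are predictable); it shows that arr[:, 2] is a view onto the same buffer with a different shape, strides and starting offset.

```python
import numpy as np

arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
print(arr.shape, arr.strides)       # (2, 3) (24, 8)

col = arr[:, 2]
print(col.shape, col.strides)       # (2,) (24,)
print(np.shares_memory(arr, col))   # True

# The view starts two 8-byte elements into the same data buffer
offset = col.__array_interface__['data'][0] - arr.__array_interface__['data'][0]
print(offset)                       # 16

# Writing through the view changes the original array
col[0] = 30
print(arr[0, 2])                    # 30
```

So a basic slice does not copy anything: the result keeps the row stride of 24 bytes and simply starts 16 bytes further along.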
A few days ago I explored one example of multidimensional indexing, and discovered some nuances that I wasn't aware of (or had ever seen documented).
For most of us, understanding the how-to of indexing is more important than knowing the implementation details. I suspect there are textbooks, papers and even Wiki pages that describe 'strided' multidimensional indexing. numpy isn't the only place that uses it.
https://numpy.org/doc/stable/reference/arrays.indexing.html
This looks like a nice intro to numpy arrays
https://ajcr.net/stride-guide-part-1/
Upvotes: 3