Reputation: 494
Basically I'm looking for a function or syntax that will let me get the first 'slice' of the last two dimensions of an n-dimensional numpy array, for an arbitrary number of dimensions.
I can do this, but it's too ugly to live with, and what if someone passes in a 6d array? There must be something in numpy like the ellipsis (...) that expands to 0,0,0,... instead of :,:,:,...
import numpy as np

data_2d = np.ones(5**2).reshape(5,5)
data_3d = np.ones(5**3).reshape(5,5,5)
data_4d = np.ones(5**4).reshape(5,5,5,5)

def get_last2d(data):
    # one hard-coded branch per supported number of dimensions
    if data.ndim == 2:
        return data[:]
    if data.ndim == 3:
        return data[0, :]
    if data.ndim == 4:
        return data[0, 0, :]

np.array_equal(get_last2d(data_3d), get_last2d(data_4d))
Thanks, Colin
Upvotes: 2
Views: 2278
Reputation: 20765
def get_last_2d(x):
    m,n = x.shape[-2:]
    return x.flat[:m*n].reshape(m,n)
This works because flattening an array returns its entries in C-style (row-major) order, in which the last indices vary the fastest. So the first m*n entries of the flattened array are exactly the entries whose leading indices are all 0, which is what you want.
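As a quick check of the ordering argument above, here is a minimal sketch that compares the function against explicit indexing with leading zeros. The test arrays `a` and `f` are my own examples, not from the question. It also suggests two things worth knowing: `x.flat` follows the logical C order even when the memory layout is Fortran-ordered, and slicing `x.flat` gives a copy rather than a view.

```python
import numpy as np

def get_last_2d(x):
    m, n = x.shape[-2:]
    return x.flat[:m * n].reshape(m, n)

# Example data (an assumption for illustration): distinct values
# make any ordering mistake visible.
a = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
first = get_last_2d(a)

# Same result as indexing every leading axis with 0.
matches_c = np.array_equal(first, a[0, 0])

# .flat iterates in logical C order regardless of memory layout,
# so a Fortran-ordered copy of the same data behaves identically.
f = np.asfortranarray(a)
matches_f = np.array_equal(get_last_2d(f), f[0, 0])

# Slicing .flat copies the data, so the result is not a view of a.
is_view = np.shares_memory(first, a)
```

If you need a view into the original array (for example, to modify it in place), basic indexing with leading zeros is the safer choice.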
Upvotes: 0
Reputation: 10791
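How about this:
def get_last2d(data):
    if data.ndim <= 2:
        return data
    # index 0 on every leading axis, full slices on the last two
    slc = [0] * (data.ndim - 2)
    slc += [slice(None), slice(None)]
    # the index must be a tuple; recent numpy versions no longer accept a list here
    return data[tuple(slc)]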
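The same idea can be written more compactly, since trailing axes that are not indexed are kept whole. This is only a sketch: the name `get_last2d_short` and the 6d test array are made up for illustration.

```python
import numpy as np

def get_last2d_short(data):
    # Hypothetical compact variant: a tuple of (ndim - 2) zeros.
    # Unindexed trailing axes are kept in full, so no explicit
    # slice objects are needed. For ndim <= 2 the tuple is empty
    # and the array is returned unchanged (as a view).
    return data[(0,) * (data.ndim - 2)]

# Example 6d input, as in the question's worry about larger arrays.
data_6d = np.arange(2**6).reshape((2,) * 6)
result = get_last2d_short(data_6d)
```

Because this is basic indexing, the result is a view into the original array, not a copy.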
Upvotes: 2