Colin Talbert

Reputation: 494

numpy slice to return last two dimensions

Basically, I'm looking for a function or syntax that will let me get the first 'slice' of the last two dimensions of a numpy array with an arbitrary number of dimensions.

I can do this, but it's too ugly to live with, and what if someone sends in a 6D array? There must be a numpy feature like the ellipsis that expands to 0,0,0,... instead of :,:,:,...

import numpy as np

data_2d = np.ones(5**2).reshape(5,5)
data_3d = np.ones(5**3).reshape(5,5,5)
data_4d = np.ones(5**4).reshape(5,5,5,5)

def get_last2d(data):
    if data.ndim == 2:
        return data[:]
    if data.ndim == 3:
        return data[0, :]
    if data.ndim == 4:
        return data[0, 0, :]

np.array_equal(get_last2d(data_3d), get_last2d(data_4d))

Thanks, Colin

Upvotes: 2

Views: 2278

Answers (2)

jme

Reputation: 20765

def get_last_2d(x):
    m,n = x.shape[-2:]
    return x.flat[:m*n].reshape(m,n)

This works because flattening an array returns entries in order of the fastest-varying index, and for C-style indexing, the last indices vary the fastest. So the first m*n entries of the flattened array are what you want.
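A quick way to convince yourself of the ordering claim is to fill an array with distinct values and compare against explicit indexing. This is only a sketch: the 4D test shape is an arbitrary choice for illustration, and note that indexing `x.flat` should give you a copy of the data rather than a view.

```python
import numpy as np

def get_last_2d(x):
    # Sizes of the last two axes.
    m, n = x.shape[-2:]
    # The first m*n entries in C order are exactly x[0, 0, ..., :, :].
    return x.flat[:m * n].reshape(m, n)

# Distinct values make it easy to see which entries were picked.
x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
first = get_last_2d(x)

print(first.shape)                      # (4, 5)
print(np.array_equal(first, x[0, 0]))   # True
```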

Upvotes: 0

farenorth

Reputation: 10791

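How about this:

def get_last2d(data):
    if data.ndim <= 2:
        return data
    # Index 0 along every leading axis, then take the last two axes whole.
    slc = [0] * (data.ndim - 2)
    slc += [slice(None), slice(None)]
    # Index with a tuple: recent NumPy versions no longer accept a list of slices here.
    return data[tuple(slc)]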
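The same idea can probably be written more compactly, since trailing axes that are left out of an index are taken whole. The sketch below builds a tuple of zeros for the leading axes only; the name `get_last2d_tuple` is made up for illustration and is not part of numpy.

```python
import numpy as np

def get_last2d_tuple(data):
    # One 0 per leading axis; the last two axes are implicitly sliced with ':'.
    # For 2D input (or smaller) the tuple is empty and the whole array comes back.
    return data[(0,) * (data.ndim - 2)]

data_2d = np.arange(5**2).reshape(5, 5)
data_4d = np.arange(5**4).reshape(5, 5, 5, 5)
data_6d = np.arange(2**6).reshape(2, 2, 2, 2, 2, 2)

print(get_last2d_tuple(data_2d).shape)  # (5, 5)
print(get_last2d_tuple(data_4d).shape)  # (5, 5)
print(get_last2d_tuple(data_6d).shape)  # (2, 2)
```

Unlike the `flat`-based approach, basic indexing like this returns a view, so writing into the result modifies the original array.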

Upvotes: 2
