Tue

Reputation: 442

Crop the last 2 dimensions of an n-dimensional array

Is there an elegant way to crop the last 2 dimensions of an n-dimensional array, where n >= 2?

Essentially I want to do the following but in a more elegant way:

def crop(self,img,row,col,l):
    if img.ndim == 2:
        img_crop = img[row:row + l, col:col + l] # 2D case
    elif img.ndim == 3:
        img_crop = img[:, row:row + l, col:col + l] # 3D case
    elif img.ndim == 4:
        img_crop = img[:,:, row:row + l, col:col + l] # 4D case
    # ... and so on, one more branch for every additional dimension
    return img_crop
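For reference, each branch of the chain above differs only in how many full slices (`:`) come before the row/column windows. The sketch below builds that index tuple explicitly so that one code path covers any ndim >= 2. It is a plain function rather than a method, and the name `crop_explicit` is hypothetical, chosen for illustration.

```python
import numpy as np

def crop_explicit(img: np.ndarray, row: int, col: int, l: int) -> np.ndarray:
    """Crop an l x l window from the last two axes of an array with ndim >= 2.

    Hypothetical helper: builds the same index the if/elif chain writes out
    by hand, i.e. one full slice (':') per leading axis, then the windows.
    """
    if img.ndim < 2:
        raise ValueError("expected an array with at least 2 dimensions")
    leading = (slice(None),) * (img.ndim - 2)  # ':' for each leading axis
    window = (slice(row, row + l), slice(col, col + l))  # row:row+l, col:col+l
    return img[leading + window]

img4 = np.arange(2 * 3 * 5 * 6).reshape(2, 3, 5, 6)
print(crop_explicit(img4, 1, 2, 3).shape)  # (2, 3, 3, 3)
```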

Upvotes: 1

Views: 171

Answers (1)

Quang Hoang

Reputation: 150825

I believe you can use ... (Python's Ellipsis) to stand in for all of the preceding axes, however many there are:

img_crop = img[..., row:row + l, col:col + l]
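A minimal sketch of the Ellipsis approach, written as a plain function instead of the question's method (dropping `self` is the only change), checked on 2-D, 3-D and 4-D inputs:

```python
import numpy as np

def crop(img: np.ndarray, row: int, col: int, l: int) -> np.ndarray:
    # '...' expands to as many ':' as needed to cover every axis before
    # the last two, so this single expression handles any ndim >= 2.
    # Basic slicing like this returns a view of img, not a copy.
    return img[..., row:row + l, col:col + l]

for shape in [(5, 6), (3, 5, 6), (2, 3, 5, 6)]:
    img = np.zeros(shape)
    print(shape, "->", crop(img, 1, 2, 3).shape)
# (5, 6) -> (3, 3)
# (3, 5, 6) -> (3, 3, 3)
# (2, 3, 5, 6) -> (2, 3, 3, 3)
```

Because the result is a view, writing into the crop modifies the original array; call `.copy()` on it if that is not wanted. A 1-D input raises an IndexError (too many indices), which matches the n >= 2 requirement.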

Upvotes: 1
