memo

Reputation: 3714

Normalise a batch of images per channel in numpy

I have a numpy ndarray of shape [batch_size, width, height, num_channels]. They're not RGB images, but they are a similar concept: 2D fields.

I'd like to normalise these images per channel. Is there a more idiomatic numpy way of doing this than the code below? In particular, I don't like the loops over channels, and it felt odd having to call np.min and np.max on a slice. The code is also hardcoded to work only on rank-4 tensors. How could it be adapted to handle a dynamic rank, or channels on an arbitrary axis?

import numpy as np

def get_img_ch_min_max(imgs):
    '''return minimum and maximum for each channel of [batch, width, height, channels]'''
    if imgs.ndim == 3: imgs = np.expand_dims(imgs, axis=-1)
    # iterate each channel
    ch_min = np.array([np.min(imgs[:,:,:,i]) for i in range(imgs.shape[-1])])
    ch_max = np.array([np.max(imgs[:,:,:,i]) for i in range(imgs.shape[-1])])
    return ch_min, ch_max


def normalise_per_channel(imgs):
    '''normalise batch of images per channel, [batch, width, height, channels]'''
    added_ch_axis = imgs.ndim == 3
    if added_ch_axis: imgs = np.expand_dims(imgs, axis=-1)
    ch_min, ch_max = get_img_ch_min_max(imgs)
    ch_range = ch_max - ch_min
    imgs_ret = imgs.astype(float)  # float copy, so integer inputs aren't truncated
    for i in range(imgs.shape[-1]): # iterate each channel
        if ch_range[i] > 0: # avoid divide by zero
            imgs_ret[:,:,:,i] = (imgs[:,:,:,i] - ch_min[i]) / ch_range[i]
    # only undo the channel axis we added (a plain squeeze would also drop batch_size == 1)
    if added_ch_axis: imgs_ret = np.squeeze(imgs_ret, axis=-1)
    return imgs_ret

Upvotes: 3

Views: 1725

Answers (1)

AGN Gazer

Reputation: 8378

This is my attempt to answer your question without having tested my solution on your data.

I think the key idea here is to use numpy.amin() (or numpy.amax()) and specify the axis argument. This avoids the loop:

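import numpy as np

rank = imgs.ndim  # 4 for [batch, width, height, channels]
# reduce over every axis except the last (channel) axis
ch_mins = np.amin(imgs, axis=tuple(range(rank - 1)))
ch_maxs = np.amax(imgs, axis=tuple(range(rank - 1)))
ch_range = ch_maxs - ch_mins
# constant channels: subtract 0 and divide by 1, i.e. leave them unchanged
idx = np.where(ch_range == 0)
ch_mins[idx] = 0
ch_range[idx] = 1
# per-channel arrays of shape (channels,) broadcast against the last axis
imgs = (imgs - ch_mins) / ch_range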
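Wrapped into a function, the same idea handles any rank as long as the channels are on the last axis: reduce over `tuple(range(imgs.ndim - 1))`. This is a minimal sketch, not tested on the asker's data. The function name mirrors the question's `normalise_per_channel`; using `np.where` instead of index assignment and casting to float are choices made here, not part of the answer above.

```python
import numpy as np


def normalise_per_channel(imgs):
    """Min-max normalise each channel of an array whose LAST axis is channels.

    Works for any rank >= 2, e.g. [batch, width, height, channels] or
    [width, height, channels]: every axis except the last is reduced.
    Constant channels (zero range) are returned unchanged, as in the
    loop version from the question.
    """
    imgs = np.asarray(imgs, dtype=float)       # work in float so integer inputs aren't truncated
    reduce_axes = tuple(range(imgs.ndim - 1))  # all axes except the channel axis
    ch_min = imgs.min(axis=reduce_axes)        # shape (num_channels,)
    ch_range = imgs.max(axis=reduce_axes) - ch_min
    constant = ch_range == 0
    # Subtract 0 and divide by 1 on constant channels so they pass through as-is.
    ch_min = np.where(constant, 0.0, ch_min)
    ch_range = np.where(constant, 1.0, ch_range)
    # Shape (num_channels,) broadcasts against the trailing (channel) axis.
    return (imgs - ch_min) / ch_range


rng = np.random.default_rng(0)
batch = rng.normal(size=(8, 32, 32, 3))   # [batch, width, height, channels]
out = normalise_per_channel(batch)
print(out.min(axis=(0, 1, 2)), out.max(axis=(0, 1, 2)))  # each channel now spans 0 to 1
```

Because no loop or fixed slice like `[:,:,:,i]` is involved, the same function works unchanged on a rank-3 `[width, height, channels]` array.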

For the "dynamic axis" situation, I would suggest using np.newaxis (or indexing such as [None, :, None, ...]) so that the per-channel minima and ranges broadcast against the right axis, and, if necessary, numpy.moveaxis() or one of the similar functions.
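As an alternative to placing `None`/`np.newaxis` by hand, the reductions can be called with `keepdims=True`. This leaves the reduced axes in place as size-1 dimensions, so the per-channel statistics broadcast back onto the input whatever the rank and wherever the channel axis sits. This is a sketch only: `keepdims` is swapped in here for the manual indexing suggested above, and `channel_axis` is a hypothetical parameter name.

```python
import numpy as np


def normalise_per_channel(imgs, channel_axis=-1):
    """Min-max normalise per channel, with the channel axis anywhere.

    channel_axis: index of the channel axis; negative values count from the end.
    Constant channels (zero range) are returned unchanged.
    """
    imgs = np.asarray(imgs, dtype=float)
    if not -imgs.ndim <= channel_axis < imgs.ndim:
        raise ValueError(f"channel_axis {channel_axis} out of range for ndim {imgs.ndim}")
    ch_axis = channel_axis % imgs.ndim  # turn a negative index into a positive one
    reduce_axes = tuple(ax for ax in range(imgs.ndim) if ax != ch_axis)
    # keepdims=True keeps the reduced axes as size 1, e.g. for shape
    # (B, C, W, H) with channel_axis=1 the statistics have shape (1, C, 1, 1),
    # so they broadcast back onto imgs without any manual None/newaxis indexing.
    ch_min = imgs.min(axis=reduce_axes, keepdims=True)
    ch_range = imgs.max(axis=reduce_axes, keepdims=True) - ch_min
    constant = ch_range == 0
    ch_min = np.where(constant, 0.0, ch_min)
    ch_range = np.where(constant, 1.0, ch_range)
    return (imgs - ch_min) / ch_range


rng = np.random.default_rng(1)
chw = rng.normal(size=(4, 3, 16, 16))          # [batch, channels, width, height]
a = normalise_per_channel(chw, channel_axis=1)
# Same result as moving channels last, normalising, and moving them back:
b = np.moveaxis(normalise_per_channel(np.moveaxis(chw, 1, -1)), -1, 1)
print(np.allclose(a, b))  # True
```

The last lines compare against the `np.moveaxis` route. Both give the same array; `keepdims` just avoids the two transposes.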

Upvotes: 1
