cmplx96

Reputation: 1651

Numpy - Normalize RGB image dataset

My dataset is a Numpy array with dimensions (N, W, H, C), where N is the number of images, H and W are height and width respectively and C is the number of channels.

I know that there are many tools out there but I would like to normalize the images with only Numpy.

My plan is to compute the mean and standard deviation across the whole dataset for each of the three channels and then subtract the mean and divide by the standard deviation.

Suppose we have two images in the dataset and the first channel of those two images, x[:,:,:,0], looks like this:

array([[[3., 4.],
        [5., 6.]],

       [[1., 2.],
        [3., 4.]]])

Compute the mean:

numpy.mean(x[:,:,:,0])
= 3.5

Compute the std:

numpy.std(x[:,:,:,0])
= 1.5

Normalize the first channel:

x[:,:,:,0] = (x[:,:,:,0] - 3.5) / 1.5

Is this correct?

Thanks!

Upvotes: 5

Views: 14309

Answers (1)

mostsquares

Reputation: 933

Looks good, but there are some things NumPy does that could make it nicer. I'm assuming that you want to normalize each channel separately.
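As a quick sanity check on the numbers in the question, here is a minimal sketch. It assumes the array shown there is the first channel, and it adds a trailing axis of length 1 only so that the (N, H, W, C) indexing works; the names channel0, mean0 and std0 are just for illustration.

```python
import numpy as np

# First channel of the two 2x2 images from the question.
channel0 = np.array([[[3., 4.],
                      [5., 6.]],
                     [[1., 2.],
                      [3., 4.]]])

# Add a trailing channel axis so the array has an (N, H, W, C) layout.
# copy() keeps the in-place update below from touching channel0.
x = channel0[..., np.newaxis].copy()  # shape (2, 2, 2, 1)

mean0 = x[:, :, :, 0].mean()  # 3.5
std0 = x[:, :, :, 0].std()    # 1.5 (population std, ddof=0)

# Normalize the first channel in place, as in the question.
x[:, :, :, 0] = (x[:, :, :, 0] - mean0) / std0
```

After this, the first channel has mean 0 and standard deviation 1 (up to floating-point error).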

For one, notice that x has a method mean, so we can write x[..., 0].mean() instead of np.mean(x[:, :, :, 0]). Also, the mean method takes the keyword argument axis, which we can use as follows:

means = x.mean(axis=(0, 1, 2)) # Take the mean over the N,H,W axes
means.shape # => will evaluate to (C,)

Then we can subtract the means from the whole dataset like so:

centered = x - x.mean(axis=(0,1,2), keepdims=True)

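Note the keepdims=True here: it keeps the reduced axes as size-1 dimensions, so the means have shape (1, 1, 1, C) and broadcast against x channel by channel. With the channels on the last axis, the (C,) array from before would broadcast too, but keepdims makes the alignment explicit and keeps working if the channel axis is not last.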
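To see what keepdims changes, here is a small sketch on a made-up random array. The shapes and variable names are assumptions for illustration, not something from the question.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((4, 5, 6, 3))  # hypothetical (N, H, W, C) data

flat = x.mean(axis=(0, 1, 2))                 # shape (3,)
kept = x.mean(axis=(0, 1, 2), keepdims=True)  # shape (1, 1, 1, 3)

# Channels-last layout: both shapes broadcast to the same result.
same = np.allclose(x - flat, x - kept)

# Channels-first layout (N, C, H, W): the reduced axes have to be kept
# so that the C axis of the means lines up with the C axis of the data.
x_cf = np.moveaxis(x, -1, 1)                        # shape (4, 3, 5, 6)
kept_cf = x_cf.mean(axis=(0, 2, 3), keepdims=True)  # shape (1, 3, 1, 1)
centered_cf = x_cf - kept_cf
```

In the channels-first case, subtracting a plain (3,) array from the (4, 3, 5, 6) array would raise a broadcasting error, because NumPy aligns shapes starting from the last axis.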

There is also an x.std that works the same way, so we can do the whole normalization in one line:

z = (x - x.mean(axis=(0,1,2), keepdims=True)) / x.std(axis=(0,1,2), keepdims=True)
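A quick way to convince yourself the one-liner does what you want is to check the per-channel statistics afterwards. This is only a sketch on made-up data; the array shape and the names channel_means and channel_stds are my own.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((8, 4, 4, 3)) * 255.0  # hypothetical (N, H, W, C) dataset

mean = x.mean(axis=(0, 1, 2), keepdims=True)  # shape (1, 1, 1, 3)
std = x.std(axis=(0, 1, 2), keepdims=True)    # shape (1, 1, 1, 3)
z = (x - mean) / std

# Each channel of z should now have mean ~0 and std ~1.
channel_means = z.mean(axis=(0, 1, 2))
channel_stds = z.std(axis=(0, 1, 2))
```

One thing to watch for: if a channel is constant over the whole dataset, its std is 0 and the division produces nan or inf values, so you may want to handle that case separately.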

Check out the docs for numpy.ndarray.mean and numpy.ndarray.std for more info. The np.ndarray.method methods are what you hit when you call x.method() instead of np.method(x).


Edit: I have since learned that, of course, there is a scipy.stats.zscore. I'm not sure if this is a more readable way to take zscores along each channel, but some might prefer it:

from scipy.stats import zscore
z = zscore(x.reshape(-1, 3)).reshape(x.shape)

The scipy function operates only over a single axis, so we have to reshape into an NHW x C matrix first and then reshape back.
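If the reshape trick looks opaque, the sketch below spells out the same idea using only NumPy, so you can see that flattening to an NHW x C matrix and normalizing along axis 0 gives the same result as the multi-axis version. The data and names are made up; zscore defaults to axis=0 and ddof=0, which is what the flat.mean and flat.std calls mirror here.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((2, 4, 5, 3))  # hypothetical (N, H, W, C) data

# Flatten all pixels of all images into rows, one column per channel.
flat = x.reshape(-1, x.shape[-1])  # shape (N*H*W, C) = (40, 3)

# Column-wise z-score, then restore the original layout.
z_flat = (flat - flat.mean(axis=0)) / flat.std(axis=0)
z_reshaped = z_flat.reshape(x.shape)

# The multi-axis version from above, for comparison.
z_direct = (x - x.mean(axis=(0, 1, 2), keepdims=True)) / x.std(axis=(0, 1, 2), keepdims=True)
```

Here np.allclose(z_reshaped, z_direct) holds, because the reshape only regroups the pixels and leaves each value in its channel column.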

Upvotes: 12
