Imahn

Reputation: 526

PyTorch: How to normalize a tensor when the image is cropped randomly?

Let's say we are working with the CIFAR-10 dataset and we want to apply some data augmentation and additionally normalize the tensors. Here is some reproducible code for this:

from torchvision import transforms, datasets
import matplotlib.pyplot as plt
trafo = transforms.Compose([transforms.Pad(padding = 4, fill = 0, padding_mode = "constant"), 
                            transforms.RandomHorizontalFlip(p=0.5),
                            transforms.RandomCrop(size = (32, 32)), 
                            transforms.ToTensor(), 
                            transforms.Normalize(mean = (0.0, 0.0, 0.0), std = (1.0, 1.0, 1.0))]
                          )

cifar10_full = datasets.CIFAR10(root = "CIFAR-10", train = True, transform = trafo, target_transform = None, download = True)

The normalization I chose so far does nothing to the tensors, since I set the mean and std to 0 and 1 respectively. According to the documentation of torchvision.transforms.Normalize, the provided means and standard deviations are per channel of the input. The problem is that I cannot calculate the mean of each channel in advance because of the random flipping and cropping. Therefore, my idea was something along the following lines:

trafo_1 = transforms.Compose([transforms.Pad(padding = 4, fill = 0, padding_mode = "constant"), 
                            transforms.RandomHorizontalFlip(p=0.5),
                            transforms.RandomCrop(size = (32, 32)), 
                            transforms.ToTensor()]
                          )

cifar10_full = datasets.CIFAR10(root = "CIFAR-10", train = True, transform = trafo_1, target_transform = None, download = True)

Now I could calculate the mean across each channel of the input, and then I would want to normalize the tensors. However, I cannot simply use transforms.Normalize(), as cifar10_full is not the original dataset anymore. How could I proceed instead? (One solution would be to simply fix the seed of the random generators, i.e. use torch.manual_seed(0), but I would like to avoid this for now.)

Upvotes: 1

Views: 3009

Answers (1)

Zabir Al Nazi Nabil

Reputation: 11218

The mean and std are not computed for each tensor but over the whole dataset. What you are trying to do doesn't really matter: you just want a scale that is good enough for the whole data representation. Because these are all random operations, there is no exact mean or std to obtain. Just use the mean and std of the actual (unaugmented) data, which is pretty much the standard practice.
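To make the "one set of statistics for the whole dataset" point concrete, here is a minimal NumPy sketch of the per-channel operation that the torchvision documentation describes for Normalize, output[channel] = (input[channel] - mean[channel]) / std[channel]. The helper name normalize_chw is made up for illustration and is not part of torchvision.

```python
import numpy as np

def normalize_chw(img, mean, std):
    """Apply (x - mean) / std per channel to a float image shaped (C, H, W)."""
    # Reshape to (C, 1, 1) so each channel's statistic broadcasts over H and W
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (img - mean) / std

# A toy 3-channel image with every pixel at 0.5
img = np.full((3, 2, 2), 0.5, dtype=np.float32)

# mean=0, std=1 leaves the image unchanged, as noted in the question
identity = normalize_chw(img, (0.0, 0.0, 0.0), (1.0, 1.0, 1.0))

# The same three means and stds are reused for every image in the dataset
shifted = normalize_chw(img, (0.5, 0.5, 0.5), (0.25, 0.25, 0.25))
```

The same mean and std tuples are applied to every image, augmented or not, which is why they only need to be a reasonable estimate for the dataset as a whole.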

First, calculate the mean and std of the dataset (random sampling is enough), and use those for normalization.

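# Estimate the per-channel mean and std of the dataset from a random sample
import glob
import random

import cv2
import numpy as np
import tqdm

means = np.zeros(3, dtype=np.float64)
stds = np.zeros(3, dtype=np.float64)
total_images = 0
randomly_sample = 5000

# "**/*.jpg" together with recursive=True also matches files in subdirectories
files = glob.glob("dataset_path/**/*.jpg", recursive=True)
# random.sample raises ValueError if asked for more items than exist
for f in tqdm.tqdm(random.sample(files, min(randomly_sample, len(files)))):
    img = cv2.imread(f)
    if img is None:  # cv2.imread returns None for files it cannot read
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; torchvision works in RGB order
    means += img.mean(axis=(0, 1))
    stds += img.std(axis=(0, 1))  # per-image std, averaged over the sample below
    total_images += 1

# Scale from [0, 255] to [0, 1] to match the output of transforms.ToTensor()
means = means / (total_images * 255.)
stds = stds / (total_images * 255.)
print("Total images: ", total_images)
print("Means: ", means)
print("Stds: ", stds)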
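For CIFAR-10 specifically, the images do not need to be read from disk one by one. The following is a sketch under the assumption that the images are already in memory as a uint8 array of shape (N, H, W, 3); as far as I know, torchvision's CIFAR10 dataset exposes such an array as its data attribute. The function name channel_stats is hypothetical, and the random array only stands in for real images. Note that this computes the std over all pixels of the sample, which is not the same quantity as the average of per-image stds used above.

```python
import numpy as np

def channel_stats(images):
    """Per-channel mean and std over all pixels of a uint8 (N, H, W, C) array, scaled to [0, 1]."""
    scaled = images.astype(np.float64) / 255.0
    # Reduce over the image, height and width axes; only the channel axis remains
    mean = scaled.mean(axis=(0, 1, 2))
    std = scaled.std(axis=(0, 1, 2))
    return mean, std

# Synthetic stand-in for the real data (assumption: for CIFAR-10 you would
# pass the dataset's uint8 image array here instead)
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8)

mean, std = channel_stats(fake_images)
# The two length-3 results can be passed as tuples to
# transforms.Normalize(mean=..., std=...)
```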

Consider a simple scenario: in actual testing or inference, will your images be augmented this way too? Probably not. You will have clean images that closely match the mean and std of the clean version of the data, so it's useless to calculate the mean and std of the augmented data (you can take a few random samples), unless you want to apply test-time augmentation (TTA).

If you want to apply TTA too, you can run the same augmentations on the images, sample randomly, and take the mean and std of those augmented images.
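A rough sketch of that sampling approach, written in NumPy rather than with torchvision transforms so that it stays self-contained. The helpers augment and augmented_channel_stats are hypothetical names; the augmentation only mimics the question's pipeline (zero padding of 4, horizontal flip with probability 0.5, random 32x32 crop) and is not the torchvision implementation.

```python
import numpy as np

def augment(img, rng, pad=4):
    """Zero-pad, randomly flip horizontally, then take a random crop of the original size."""
    h, w, _ = img.shape
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode="constant")
    if rng.random() < 0.5:
        padded = padded[:, ::-1, :]  # horizontal flip
    # Crop offsets range over 0..2*pad inclusive (upper bound is exclusive)
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    return padded[top:top + h, left:left + w, :]

def augmented_channel_stats(images, n_samples, seed=0):
    """Per-channel mean and std, in [0, 1], over randomly sampled augmented images."""
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, len(images), size=n_samples)
    batch = np.stack([augment(images[i], rng) for i in idx])
    scaled = batch.astype(np.float64) / 255.0
    return scaled.mean(axis=(0, 1, 2)), scaled.std(axis=(0, 1, 2))

# Synthetic all-white images standing in for a real uint8 (N, 32, 32, 3) array
white = np.full((10, 32, 32, 3), 255, dtype=np.uint8)
aug_mean, aug_std = augmented_channel_stats(white, n_samples=200)
```

Because the zero padding brings black borders into many crops, the augmented mean comes out somewhat lower than the clean mean, which illustrates why the two sets of statistics differ only slightly.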

Upvotes: 1
