Beom Seok Park

Reputation: 21

How to normalize images in PyTorch

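import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor()  # PIL image -> float tensor in [0, 1], shape (C, H, W)
])

trainset = torchvision.datasets.ImageFolder(root='C:/Users/beomseokpark/Desktop/CNN/train_data', transform=transform)
data_loader = DataLoader(dataset=trainset, batch_size=8, shuffle=True, num_workers=2)

with torch.no_grad():
    for num, data in enumerate(trainset):
        imgs, label = data  # one image of shape (C, H, W) and its class index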

I loaded the images with ImageFolder from the torchvision library. How can I get the mean and std of each channel of my images?

Can anyone please help me out?

Upvotes: 2

Views: 2914

Answers (2)

Shai

Reputation: 114796

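There's the "lazy man" approach: simply plug an nn.BatchNorm2d in as the very first layer of your network. With an appropriate momentum (e.g. momentum=None, which keeps a cumulative average instead of an exponential moving average) and track_running_stats=True, this layer will estimate your data's mean and variance for you and store them in its running_mean and running_var buffers.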
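A minimal sketch of this approach, assuming the images arrive as float batches of shape (B, 3, H, W). It uses a standalone nn.BatchNorm2d with affine=False (no learnable scale/shift) and momentum=None (equal-weight average of per-batch statistics); the helper name `batchnorm_channel_stats` and the random `batches` list are hypothetical stand-ins for a real network and data loader.

```python
import torch
import torch.nn as nn


def batchnorm_channel_stats(batches, num_channels=3):
    """Estimate per-channel mean and variance with a BatchNorm2d layer.

    momentum=None makes BatchNorm average the per-batch statistics with
    equal weight instead of using an exponential moving average.
    """
    bn = nn.BatchNorm2d(num_channels, affine=False,
                        track_running_stats=True, momentum=None)
    bn.train()  # running statistics are only updated in training mode
    with torch.no_grad():
        for imgs in batches:
            bn(imgs)  # updates bn.running_mean / bn.running_var in place
    return bn.running_mean.clone(), bn.running_var.clone()


# Stand-in data: 10 batches of uniform values in [1, 3) (mean 2, variance 1/3).
torch.manual_seed(0)
batches = [torch.rand(8, 3, 32, 32) * 2 + 1 for _ in range(10)]
mean, var = batchnorm_channel_stats(batches)
print(mean, var)
```

Note that running_var is an average of per-batch (unbiased) variances, so it only approximates the variance over the whole dataset, and running_mean equals the dataset mean exactly only when all batches have the same size.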

Alternatively, you can compute the mean and variance explicitly by accumulating per-channel sums:

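import torch

# Per-channel running sums of x and x**2 (float64 to limit round-off).
mu = torch.zeros((3,), dtype=torch.float64)
sq = torch.zeros((3,), dtype=torch.float64)
n = 0  # pixels seen per channel
with torch.no_grad():
    # Iterate over the DataLoader, not the dataset: each batch has shape
    # (B, C, H, W), whereas a single dataset item is only (C, H, W).
    # (Batching requires all images to share one size.)
    for imgs, _ in data_loader:
        imgs = imgs.double()
        mu += torch.sum(imgs, dim=(0, 2, 3))
        sq += torch.sum(imgs ** 2, dim=(0, 2, 3))
        n += imgs.numel() // imgs.shape[1]  # B * H * W pixels per channel
mu = mu / n              # mean
var = sq / n - mu ** 2   # variance: E[x^2] - (E[x])^2
std = var.sqrt()         # standard deviation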
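If the images in the folder do not all share one size, the default DataLoader cannot stack them into batches. A sketch of the same sum / sum-of-squares idea that walks the dataset item by item instead, summing over the spatial dimensions (1, 2) of each (C, H, W) tensor. The `channel_stats` helper and the synthetic `fake_dataset` are hypothetical stand-ins, not part of torchvision; each item is assumed to be an (img, label) pair as ImageFolder returns.

```python
import torch


def channel_stats(dataset, num_channels=3):
    """Per-channel mean and (population) std over every pixel of every image.

    Works item by item, so images may have different heights/widths.
    """
    total = torch.zeros(num_channels, dtype=torch.float64)
    total_sq = torch.zeros(num_channels, dtype=torch.float64)
    n_pixels = 0
    with torch.no_grad():
        for img, _ in dataset:
            img = img.double()
            total += img.sum(dim=(1, 2))           # sum over H and W
            total_sq += (img ** 2).sum(dim=(1, 2))
            n_pixels += img.shape[1] * img.shape[2]
    mean = total / n_pixels
    var = total_sq / n_pixels - mean ** 2          # E[x^2] - (E[x])^2
    return mean, var.clamp(min=0).sqrt()           # clamp guards tiny negatives


# Synthetic stand-in for an ImageFolder dataset: images of different sizes.
torch.manual_seed(0)
fake_dataset = [(torch.rand(3, h, w), 0) for h, w in [(32, 32), (48, 40), (64, 16)]]
mean, std = channel_stats(fake_dataset)
print(mean, std)
```

Because every pixel is weighted equally, larger images contribute proportionally more to the statistics than smaller ones.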

Upvotes: 3

Alexey Golyshev

Reputation: 812

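import torch as t

batch_size = 8
imgs = t.empty(batch_size, 3, 128, 128).normal_()  # dummy batch (B, C, H, W)

# Put the channel axis first and flatten the rest: shape (C, B*H*W)
per_channel = t.nn.Flatten(start_dim=1)(imgs.permute(1, 0, 2, 3))
mean = per_channel.mean(dim=1)  # per-channel mean
std = per_channel.std(dim=1)    # per-channel (unbiased) std
print(std.shape)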

torch.Size([3])
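The permute-and-flatten trick is equivalent to reducing over the batch and spatial dimensions directly, and the resulting statistics are what transforms.Normalize(mean, std) expects: it computes (x - mean) / std for each channel. A sketch with plain tensor broadcasting (no torchvision), assuming a single dummy batch like the one above; note that this gives the statistics of that one batch only, not of the whole dataset.

```python
import torch

torch.manual_seed(0)
imgs = torch.rand(8, 3, 128, 128) * 0.5 + 0.2  # dummy batch (B, C, H, W)

# Channel axis first, everything else flattened: shape (C, B*H*W).
per_channel = torch.flatten(imgs.permute(1, 0, 2, 3), start_dim=1)
mean = per_channel.mean(dim=1)
std = per_channel.std(dim=1)

# Same means by reducing over batch, height and width directly.
mean_direct = imgs.mean(dim=(0, 2, 3))

# Per-channel normalization, i.e. what transforms.Normalize(mean, std) computes;
# mean/std are reshaped to (C, 1, 1) so they broadcast over (B, C, H, W).
normalized = (imgs - mean[:, None, None]) / std[:, None, None]

check = torch.flatten(normalized.permute(1, 0, 2, 3), start_dim=1)
print(check.mean(dim=1), check.std(dim=1))  # approximately 0 and 1 per channel
```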

Upvotes: 2
