kett

Reputation: 967

PyTorch - How to use "toPILImage" correctly

I would like to know whether I am using toPILImage from torchvision correctly. I want to use it to see what the images look like after the initial image transformations have been applied to the dataset.

When I use it as in the code below, the image that comes up has strange colors. The original image is a regular RGB image.

This is my code:

import os
import torch
from PIL import Image, ImageFont, ImageDraw
import torch.utils.data as data
import torchvision
from torchvision import transforms    
import matplotlib.pyplot as plt

# Image transformations
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
    )
transform_img = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    normalize ])

train_data = torchvision.datasets.ImageFolder(
    root='./train_cl/',
    transform=transform_img
    )
test_data = torchvision.datasets.ImageFolder(
    root='./test_named_cl/',
    transform=transform_img                                             
    )

train_data_loader = data.DataLoader(train_data,
    batch_size=4,
    shuffle=True,
    num_workers=4) #num_workers=args.nThreads)

test_data_loader = data.DataLoader(test_data,
    batch_size=32,
    shuffle=False,
    num_workers=4)        

# Open Image from dataset:
to_pil_image = transforms.ToPILImage()
my_img, _ = train_data[248]
results = to_pil_image(my_img)
results.show()

Edit:

I had to use .data on the Torch Variable to get the tensor. I also needed to rescale the NumPy array before transposing it. I found a working solution here, but it doesn't always work well. How can I do this better?

import numpy as np
from torch.autograd import Variable

# Take the first batch from the loader (named `batch` so it does not
# shadow the `torch.utils.data as data` import above)
for i, batch in enumerate(train_data_loader, 0):
    img, labels = batch
    img = Variable(img)
    break

# First image of the batch as a (C, H, W) NumPy array
image = img.data.cpu().numpy()[0]

# This worked for rescaling:
image = (1/(2*2.25)) * image + 0.5

# Both of these didn't work:
# image /= (image.max()/255.0)
# image *= (255.0/image.max())

# (C, H, W) -> (H, W, C) for matplotlib
image = np.transpose(image, (1, 2, 0))
plt.imshow(image)
plt.show()
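The range of values that Normalize produces explains why the max-based rescaling fails. A minimal NumPy sketch, assuming the pixels coming out of ToTensor() lie in [0, 1] and using the mean/std from the code above, computes the per-channel extremes after normalization:

```python
import numpy as np

# Mean and std used by transforms.Normalize in the question's code
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# ToTensor() yields values in [0, 1]; Normalize maps x -> (x - mean) / std
lowest = (0.0 - mean) / std   # per-channel value of a black pixel
highest = (1.0 - mean) / std  # per-channel value of a white pixel

print("min per channel:", lowest.round(3))
print("max per channel:", highest.round(3))
```

The normalized values are negative as well as positive, so dividing by image.max() alone cannot map them into [0, 1]. The 2.25 in the working rescale roughly matches the red channel's maximum of about 2.249, but it is only an approximation for the other channels.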

Upvotes: 11

Views: 46299

Answers (3)

Jiapeng Yu

Reputation: 1

I recently ran into the same problem. The reason your image looks so different is transforms.Normalize. Each image you get from the dataset has already been transformed by x = (x - mean)/std, where x is the image. To get the original image back, you have to apply the inverse, x = x*std + mean. Here is my solution.

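import numpy as np
import torch

def _inverse_norm(images):
    if isinstance(images, torch.Tensor):
        # Tensor image (C, H, W) to NumPy array (H, W, C)
        images = images.cpu().permute(1, 2, 0).numpy()
        NORM_MEAN = np.array([0.485, 0.456, 0.406])
        NORM_STD = np.array([0.229, 0.224, 0.225])
        # Undo x = (x - mean) / std, broadcasting over height and width
        images = (images * NORM_STD[None, None]) + NORM_MEAN[None, None]
        # Keep values in the valid [0, 1] range for display
        images = np.clip(images, a_min=0.0, a_max=1.0)
    return images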
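The round trip can be checked without a dataset. A minimal NumPy sketch, assuming a single channel-first (C, H, W) float image in [0, 1] like the output of ToTensor(); normalize_chw and denormalize_chw are hypothetical helper names, not torchvision functions:

```python
import numpy as np

NORM_MEAN = np.array([0.485, 0.456, 0.406])
NORM_STD = np.array([0.229, 0.224, 0.225])

def normalize_chw(img):
    """Hypothetical helper: per-channel (x - mean) / std, the formula Normalize applies."""
    return (img - NORM_MEAN[:, None, None]) / NORM_STD[:, None, None]

def denormalize_chw(img):
    """Hypothetical helper: per-channel x * std + mean, clipped to [0, 1]."""
    out = img * NORM_STD[:, None, None] + NORM_MEAN[:, None, None]
    return np.clip(out, 0.0, 1.0)

# A random 3x4x4 "image" standing in for a ToTensor() output
rng = np.random.default_rng(0)
original = rng.random((3, 4, 4))

restored = denormalize_chw(normalize_chw(original))
print("max round-trip error:", np.abs(restored - original).max())

# Channel-last (H, W, C) layout for matplotlib or PIL
hwc = np.transpose(restored, (1, 2, 0))
print("display shape:", hwc.shape)
```

Working on the channel-first array and transposing only at the end keeps the broadcasting shapes identical to the ones Normalize uses.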

Upvotes: 0

Steven

Reputation: 5162

You can use a PIL image, but you're not actually loading the data the way you normally would.

Try something like this instead:

import numpy as np
import matplotlib.pyplot as plt

for img, labels in train_data_loader:
    # load a batch from train data
    break

# move the batch to the CPU, convert it to NumPy and select the first image
img = img.cpu().numpy()[0]
# convert the image from (Channels, Height, Width) back to (Height, Width, Channels)
img = np.transpose(img, (1, 2, 0))
# show the image
plt.imshow(img)
plt.show()
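Note that this still displays the normalized values, and matplotlib clips float RGB input to [0, 1], which distorts the colors. A minimal sketch of undoing the normalization for a whole batch before transposing, assuming an (N, C, H, W) NumPy array such as img.cpu().numpy() from the loader above (replaced here by random data) and the question's mean/std; batch_to_display is a hypothetical helper name:

```python
import numpy as np

# Shaped (1, C, 1, 1) so they broadcast over batch, height and width
mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)

def batch_to_display(batch_nchw):
    """Undo Normalize for an (N, C, H, W) batch and return (N, H, W, C) values in [0, 1]."""
    unnormalized = batch_nchw * std + mean
    unnormalized = np.clip(unnormalized, 0.0, 1.0)
    return np.transpose(unnormalized, (0, 2, 3, 1))

# Stand-in for img.cpu().numpy(): 4 images of size 8x8, already normalized
rng = np.random.default_rng(1)
clean = rng.random((4, 3, 8, 8))
normalized = (clean - mean) / std

display = batch_to_display(normalized)
print(display.shape)
```

plt.imshow(display[0]) would then show the first image with its original colors.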

As an update (02-10-2021):

import torch
import torchvision.transforms.functional as F
# load the image (creating a random image as an example)
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
pil_image = F.to_pil_image(img_data)

Alternatively:

import torch
import torchvision.transforms as transforms
img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
pil_image = transforms.ToPILImage()(img_data)

The second form can be integrated into a dataset's transform pipeline in PyTorch, or called directly as shown above.

I added a modified to_pil_image here.

Essentially, it does what I suggested back in 2018, but it is now integrated into PyTorch.
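Both forms expect pixel values that map cleanly onto 8-bit integers. As far as I understand torchvision's implementation, a floating-point input is multiplied by 255 and cast to 8-bit without clamping, so normalized values outside [0, 1] do not become valid pixel values. A minimal NumPy sketch of a safer conversion, assuming a (C, H, W) float image; to_uint8_hwc is a hypothetical helper whose output could be passed to F.to_pil_image or PIL's Image.fromarray:

```python
import numpy as np

def to_uint8_hwc(img_chw):
    """Hypothetical helper: (C, H, W) floats in [0, 1] -> (H, W, C) uint8.

    Values are clipped first, so anything outside [0, 1] saturates to 0 or 255
    instead of producing garbage after the integer cast.
    """
    clipped = np.clip(img_chw, 0.0, 1.0)
    scaled = np.rint(clipped * 255.0).astype(np.uint8)
    return np.transpose(scaled, (1, 2, 0))

# A 3x2x2 float image containing out-of-range values, as a normalized tensor might
img = np.array([
    [[-0.5, 0.0], [0.5, 1.2]],
    [[0.25, 1.0], [0.0, 0.75]],
    [[1.0, 1.0], [0.0, 0.0]],
])
out = to_uint8_hwc(img)
print(out.dtype, out.shape)
```

Clipping is a display-only choice; for an image that went through Normalize you would still undo the normalization first.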

Upvotes: 11

SpeedOfSpin

Reputation: 1690

I would use something like this:

# Open Image from dataset:
my_img, _ = train_data[248]
results = transforms.ToPILImage()(my_img)
results.show()
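This converts the tensor directly, but train_data[248] has already passed through Normalize, so the colors will still be off. One way to invert it is a second Normalize with mean -mean/std and std 1/std, because ((y + mean/std) * std) = y*std + mean = x. A NumPy check of that identity, assuming the question's mean and std; apply_normalize is a hypothetical stand-in for the per-channel formula, and the resulting lists could then be passed as transforms.Normalize(mean=inv_mean, std=inv_std) before ToPILImage():

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Parameters for a second Normalize that undoes the first one:
# (y - inv_mean) / inv_std == y * std + mean
inv_mean = (-mean / std).tolist()
inv_std = (1.0 / std).tolist()

def apply_normalize(img_chw, m, s):
    """Hypothetical stand-in for Normalize's per-channel formula: (x - m) / s."""
    m = np.asarray(m)[:, None, None]
    s = np.asarray(s)[:, None, None]
    return (img_chw - m) / s

rng = np.random.default_rng(2)
x = rng.random((3, 5, 5))
y = apply_normalize(x, mean, std)               # what the dataset returns
x_back = apply_normalize(y, inv_mean, inv_std)  # the inverse Normalize
print("recovered:", np.allclose(x, x_back))
```

Doing the inversion as another Normalize keeps it inside the transform vocabulary, so it can sit in a Compose just before ToPILImage.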

Upvotes: 7
