Hamza Usman

Reputation: 203

img should be PIL Image. Got <class 'torch.Tensor'>

I'm trying to iterate through a DataLoader to check that it works, but I get the following error:

TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

I've tried adding both transforms.ToTensor() and transforms.ToPILImage(), and each gives me an error asking for the opposite: with ToPILImage() it asks for a tensor, and with ToTensor() it asks for a PIL Image.

# Imports here
%matplotlib inline
import matplotlib.pyplot as plt
from torch import nn, optim
import torch.nn.functional as F
import torch
from torchvision import transforms, datasets, models
import seaborn as sns
import pandas as pd
import numpy as np

data_dir = 'flowers'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
test_dir = data_dir + '/test'

# Creating transform for training set
train_transforms = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Creating transform for test set
test_transforms = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Transforming all data
train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
test_data = datasets.ImageFolder(test_dir, transform=test_transforms)
valid_data = datasets.ImageFolder(valid_dir, transform=test_transforms)

# Creating data loaders for test and training sets
trainloader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)
images, labels = next(iter(trainloader))

If everything works, I should be able to view the image simply by running plt.imshow(images[0]).
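Even once the loader works, plt.imshow(images[0]) will not display the image directly: each image is a normalized tensor of shape (C, H, W), while Matplotlib expects an (H, W, C) array with values in [0, 1]. A minimal sketch of a display helper, assuming the ImageNet mean/std used in the Normalize() step above; tensor_to_image and show_image are hypothetical names, not part of torchvision:

```python
import numpy as np
import torch
import matplotlib.pyplot as plt

# Assumed: the same ImageNet statistics passed to transforms.Normalize() above
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def tensor_to_image(img: torch.Tensor) -> np.ndarray:
    """Convert a normalized (C, H, W) tensor into an (H, W, C) array in [0, 1]."""
    arr = img.detach().cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
    arr = arr * STD + MEAN  # undo Normalize(): x * std + mean, per channel
    return np.clip(arr, 0.0, 1.0)  # clip small overshoots so imshow accepts it


def show_image(img: torch.Tensor, ax=None):
    """Display one image taken from a batch produced by the transforms above."""
    if ax is None:
        _, ax = plt.subplots()
    ax.imshow(tensor_to_image(img))
    ax.axis("off")
    return ax
```

Usage: `images, labels = next(iter(trainloader))` followed by `show_image(images[0])`.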

Upvotes: 20

Views: 42690

Answers (2)

Biplob Das

Reputation: 3114

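Add transforms.ToPILImage() at the start of the pipeline to convert a tensor or NumPy array into a PIL Image, so that the PIL-only transforms after it work. This is only needed when your dataset yields tensors or arrays (ImageFolder already returns PIL images by default), and RandomHorizontalFlip() must still come before ToTensor(). Example:

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])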

Upvotes: 17

Anubhav Singh

Reputation: 8719

In torchvision releases before 0.8, transforms.RandomHorizontalFlip() works only on PIL images, not on torch.Tensor. In your code, transforms.ToTensor() runs before transforms.RandomHorizontalFlip(), so the flip receives a tensor and raises this error.

According to the official PyTorch documentation:

transforms.RandomHorizontalFlip(): "Horizontally flip the given PIL Image randomly with a given probability."

So just reorder the transforms in your code so that RandomHorizontalFlip() is applied before ToTensor():

train_transforms = transforms.Compose([transforms.Resize(255), 
                                       transforms.CenterCrop(224),  
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(), 
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 

Upvotes: 35
