user3582433

Reputation: 469

Pytorch: load dataset of grayscale images

I want to load a dataset of grayscale images. I used ImageFolder, but it doesn't load grayscale images by default, as it converts images to RGB.

I found solutions that load images with ImageFolder and then convert them to grayscale, using:

transforms.Grayscale(num_output_channels=1)

or

ImageOps.grayscale(image)

Is that correct? How can I load grayscale images without conversion? I tried ImageDataBunch, but I have problems importing fastai.vision.

Upvotes: 3

Views: 10247

Answers (3)

sailfish009

Reputation: 2929

Make a custom loader and pass it to ImageFolder:

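import numpy as np
from PIL import Image, ImageOps
from torchvision.datasets import ImageFolder

def gray_reader(image_path):
    # Open the file and convert it to 8-bit grayscale ("L" mode)
    with Image.open(image_path) as im:
        im2 = ImageOps.grayscale(im)
    return np.array(im2)   # return np array (H x W)
    # return im2           # return PIL Image

some_dataset = ImageFolder(image_root_path, loader=gray_reader)  # image_root_path: your dataset root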

Edit:

The code below is much better than the previous one: load the color image and convert it to grayscale in the transform:

from torchvision import transforms

def get_transformer(h, w):
    valid_transform = transforms.Compose([
        transforms.ToPILImage(),  # expects a tensor or np.ndarray, not a PIL Image
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((h, w)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]) ])
    return valid_transform

Upvotes: 1

Tanya Jain

Reputation: 483

Assuming the dataset is stored in a "Dataset" folder laid out as below, set the root directory to "Dataset":

Dataset

  • class_1
    • img1.png
    • img2.png
  • class_2
    • img1.png
    • img2.png

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

root = 'Dataset/'

data_transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),
                                     transforms.ToTensor()])
dataset = ImageFolder(root, transform=data_transform)

For reference, the dataset is split into 70% training and 30% test data.

# Split the dataset into train and test subsets
train_size = int(0.7 * len(dataset))
test_size = len(dataset) - train_size
train_data, test_data = random_split(dataset, [train_size, test_size])
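`random_split` shuffles differently on every run. A sketch of making the split reproducible by passing a seeded `torch.Generator`; the `TensorDataset` of 50 zero "images" is a stand-in for the ImageFolder dataset, and the seed 42 is arbitrary.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the ImageFolder dataset: 50 fake 1 x 28 x 28 images with labels
dataset = TensorDataset(torch.zeros(50, 1, 28, 28), torch.zeros(50, dtype=torch.long))

train_size = int(0.7 * len(dataset))   # 35
test_size = len(dataset) - train_size  # 15, so the sizes always add up to len(dataset)

# A seeded generator makes the split the same across runs
generator = torch.Generator().manual_seed(42)
train_data, test_data = random_split(dataset, [train_size, test_size], generator=generator)

print(len(train_data), len(test_data))  # 35 15
```

Computing `test_size` as the remainder, rather than as `int(0.3 * len(dataset))`, avoids the two sizes not summing to the dataset length because of rounding.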

The train and test subsets can then be wrapped in data loaders, as below, to process them in batches.

Usually you will see a single batch_size used for both the train and test loaders, but here I define them separately. It is convenient to pick each batch_size so that it divides the size of its subset evenly; otherwise DataLoader does not raise an error, but the last batch is smaller (pass drop_last=True to discard it).

# Set batch size of train data loader
batch_size_train = 20

# Set batch size of test data loader
batch_size_test = 22

# load the split train and test data into batches via DataLoader()
train_loader = DataLoader(train_data, batch_size=batch_size_train, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size_test, shuffle=True)
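To see what happens when the batch size does not divide the number of samples, here is a small sketch with a synthetic `TensorDataset` of 50 integers (a stand-in, not the image data):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 50 samples, batch_size=20: 50 is not a multiple of 20
data = TensorDataset(torch.arange(50))

sizes = [len(batch[0]) for batch in DataLoader(data, batch_size=20)]
print(sizes)  # [20, 20, 10]: no error, the last batch is just smaller

sizes_dropped = [len(batch[0]) for batch in DataLoader(data, batch_size=20, drop_last=True)]
print(sizes_dropped)  # [20, 20]: the incomplete last batch is discarded
```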

Upvotes: 7

Szymon Maszke

Reputation: 24681

Yes, that is correct. AFAIK ImageFolder's default loader opens images with Pillow and converts them to RGB (see e.g. answers to this question), so converting to grayscale (in a custom loader or in a transform) is the way to go, though it takes time, of course.

Pure PyTorch solution (if ImageFolder isn't appropriate)

You can roll your own data loading functionality, and if I were you I wouldn't go the fastai route, as it's pretty high-level and takes control away from you (you might not need those features anyway).

In principle, all you have to do is to create something like this below:

import pathlib

import torch
from PIL import Image


class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, path: pathlib.Path, images_class: int, pattern="*.png"):
        self.files = sorted(path.glob(pattern))  # glob pattern, not a regex
        self.images_class: int = images_class

    def __len__(self):
        # Required by torch.utils.data.ConcatDataset, which `+` creates
        return len(self.files)

    def __getitem__(self, index):
        # "L" is single-channel grayscale ("LA" would add an alpha channel)
        with Image.open(self.files[index]) as image:
            return image.convert("L"), self.images_class


# Assuming you have `png` images; change the glob pattern for other formats
final_dataset = (
    ImageDataset(pathlib.Path("/path/to/dogs/images"), 0)
    + ImageDataset(pathlib.Path("/path/to/cats/images"), 1)
    + ImageDataset(pathlib.Path("/path/to/turtles/images"), 2)
)

The above collects the images from the given paths, and each image is returned together with the class index you provided for its folder.

This gives you more flexibility (e.g. a different folder layout than torchvision.datasets.ImageFolder expects) for a few more lines of code.

Of course, you could add more of those, use a loop, or whatever else fits. You could also apply torchvision.transforms, e.g. to transform the images above into tensors.

torchdata solution

Disclaimer: author here. If you are concerned about the loading time of your data and the grayscale transformation, you could use torchdata, a third-party library for PyTorch.

Using it, one could create the same thing as above but use cache or map (to apply torchvision.transforms or other transformations easily), plus some other features known e.g. from TensorFlow's tf.data module, see below:

import pathlib

from PIL import Image

import torchdata


# Change inheritance
class ImageDataset(torchdata.Dataset):
    def __init__(self, path: pathlib.Path, images_class: int, regex="*.png"):
        super().__init__()  # And add constructor call and that's it
        self.files = [file for file in path.glob(regex)]
        self.images_class: int = images_class

    def __getitem__(self, index):
        return Image.open(self.files[index]), self.images_class


final_dataset = (
    ImageDataset(pathlib.Path("/path/to/dogs/images"), 0)
    + ImageDataset(pathlib.Path("/path/to/cats/images"), 1)
    + ImageDataset(pathlib.Path("/path/to/turtles/images"), 2)
).cache()  # will cache data in-memory after first pass
# You could apply transformations after caching for possible speed-up
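If you'd rather avoid the extra dependency, the in-memory caching idea itself can be sketched with plain `torch.utils.data`. This is a swapped-in, hypothetical `CachedDataset` wrapper, not torchdata's API: it memoizes each item on first access. With `num_workers > 0`, each worker process would keep its own separate cache.

```python
from torch.utils.data import Dataset


class CachedDataset(Dataset):
    """Hypothetical wrapper: keeps items of `dataset` in memory after first access."""

    def __init__(self, dataset):
        self.dataset = dataset
        self._cache = {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        if index not in self._cache:
            self._cache[index] = self.dataset[index]  # slow path: load once
        return self._cache[index]


class CountingDataset(Dataset):
    """Toy dataset that counts how often __getitem__ is really called."""

    def __init__(self, n):
        self.n = n
        self.loads = 0

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        self.loads += 1
        return index * 2


base = CountingDataset(5)
cached = CachedDataset(base)
for _ in range(3):  # three "epochs"
    items = [cached[i] for i in range(len(cached))]
print(items, base.loads)  # [0, 2, 4, 6, 8] 5
```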

torchvision ImageFolder loader

As correctly pointed out by @jodag in the comments, one can pass a loader callable that takes a single argument (the path) to customize how files are opened, e.g. for grayscale it could be:

from PIL import Image

import torchvision

dataset = torchvision.datasets.ImageFolder(
    "/path/to/images", loader=lambda path: Image.open(path).convert("L")  # "L": 1-channel grayscale
)

Note that you could also use it for other types of files; they don't have to be images.

Upvotes: 2
