Oscar Johansson

Reputation: 25

PyTorch - Import dataset with images as labels

I have a dataset containing images as inputs and labels/targets as images as well. The structure in the folder is as follows:

```
DATASET/
    TRAIN/
        image_xx.png
        label_xx.png
    TEST/
        image_xx.png
        label_xx.png
```

I've tried using "ImageFolder" from torchvision's datasets to load the images as follows:

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

TRAIN_PATH = '/path/to/dataset/DATASET'
train_data = datasets.ImageFolder(root=TRAIN_PATH, transform=transforms.ToTensor())
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
```

However, when I inspect the first batch:

```python
for img, label in train_loader:
    print(img.shape)
    print(label.shape)
    break
```

```
torch.Size([16, 3, 128, 128])
torch.Size([16])
```

The labels aren't images but a 1-D tensor of what look like class indices. Is there a convenient way to load a dataset with the structure above?

Upvotes: 1

Views: 2321

Answers (1)

GoodDeeds

Reputation: 8527

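The ImageFolder dataset is meant for classification, where each image has a single discrete class. It expects one subdirectory per class and labels every image with the index of the subdirectory it sits in. Pointed at DATASET/, it therefore treats TRAIN and TEST as two classes and loads both the image_xx.png and the label_xx.png files as inputs, which is why each batch of labels is just 16 integers.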

For your case, you can define your own subclass of torch.utils.data.Dataset whose __getitem__ returns an (input image, label image) pair. This tutorial may be helpful.

A simple example (I have not tried running it to see if it works correctly):

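```python
import os

import cv2
import torch
from torchvision import transforms


class MyDataset(torch.utils.data.Dataset):

    def __init__(self, root_path, transform=None):
        # os.listdir returns bare file names, so join them with root_path.
        # Sorting keeps image_xx.png and label_xx.png aligned by their suffix.
        files = sorted(os.listdir(root_path))
        self.data_paths = [os.path.join(root_path, f) for f in files if f.startswith("image")]
        self.label_paths = [os.path.join(root_path, f) for f in files if f.startswith("label")]
        self.transform = transform

    def __getitem__(self, idx):
        # cv2.imread returns an HxWxC uint8 array in BGR order; convert to RGB
        img = cv2.cvtColor(cv2.imread(self.data_paths[idx]), cv2.COLOR_BGR2RGB)
        label = cv2.cvtColor(cv2.imread(self.label_paths[idx]), cv2.COLOR_BGR2RGB)
        if self.transform:
            # Convert the target as well, so it also becomes a CxHxW tensor
            img = self.transform(img)
            label = self.transform(label)
        return img, label

    def __len__(self):
        return len(self.data_paths)


TRAIN_PATH = '/path/to/dataset/DATASET/TRAIN/'
train_data = MyDataset(root_path=TRAIN_PATH, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)
```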
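The class above relies on two sorted lists lining up, which only holds if every image_xx.png has exactly one matching label_xx.png. A minimal sketch of a stricter pairing step, assuming the image_<id>/label_<id> naming from the question; pair_image_label_files is a hypothetical helper, not part of PyTorch:

```python
import os


def pair_image_label_files(root_path):
    """Return (image_path, label_path) pairs matched on their shared id.

    Hypothetical helper: assumes files are named image_<id>.<ext> and
    label_<id>.<ext>, as in the question's DATASET/TRAIN folder.
    """
    images, labels = {}, {}
    for name in os.listdir(root_path):
        stem, _ = os.path.splitext(name)
        prefix, sep, key = stem.partition("_")
        if not sep:
            continue  # skip files without an underscore
        if prefix == "image":
            images[key] = os.path.join(root_path, name)
        elif prefix == "label":
            labels[key] = os.path.join(root_path, name)

    # Fail loudly instead of silently misaligning inputs and targets
    unpaired = sorted(set(images) ^ set(labels))
    if unpaired:
        raise ValueError(f"unpaired files for ids: {unpaired}")

    return [(images[k], labels[k]) for k in sorted(images)]
```

Inside MyDataset, __init__ could store self.pairs = pair_image_label_files(root_path), with __getitem__ reading both paths from self.pairs[idx] and __len__ returning len(self.pairs).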

Upvotes: 2
