Siddharth

Reputation: 98

Generating text/csv file for image path and mask path for semantic segmentation

I have a large set of images (60k) and masks (60k) that need to be loaded into a PyTorch DataLoader for semantic segmentation.

Directory Structure:

 - Segmentation
       -images
           -color_left_trajectory_3000_00001.jpg
           -color_left_trajectory_3000_00002.jpg
           ...
       -masks
           -color_segmentation_3000_00001.jpg
           -color_segmentation_3000_00002.jpg
           ...

I want to know the most efficient way to load these into a DataLoader in PyTorch. I was thinking of generating a CSV file with the paths to the images and masks. How would I go about generating it? Any other suggestions are appreciated!
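For reference, generating such a CSV needs only the standard library csv and os modules. The sketch below assumes the pairing rule implied by the listing above: the last two underscore-separated fields of an image name (e.g. 3000_00001) identify its mask, and masks keep the .jpg extension. The function names mask_name_for and write_pairs_csv and the column names image/mask are hypothetical.

```python
import csv
import os


def mask_name_for(image_name):
    """Map an image file name to its mask file name.

    Assumption: the last two underscore-separated fields of the stem
    (e.g. "3000_00001") identify the image/mask pair, and the mask
    uses the same extension as the image.
    """
    stem, ext = os.path.splitext(image_name)
    key = "_".join(stem.split("_")[-2:])
    return f"color_segmentation_{key}{ext}"


def write_pairs_csv(data_dir, csv_path):
    """Write one (image_path, mask_path) row per image that has a mask.

    Returns the number of pairs written. Images whose mask file is
    missing are skipped instead of producing broken rows.
    """
    images_dir = os.path.join(data_dir, "images")
    masks_dir = os.path.join(data_dir, "masks")
    written = 0
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["image", "mask"])  # header row
        # sorted() gives a deterministic row order; os.listdir() does not
        for image_name in sorted(os.listdir(images_dir)):
            mask_path = os.path.join(masks_dir, mask_name_for(image_name))
            if not os.path.isfile(mask_path):
                continue
            writer.writerow([os.path.join(images_dir, image_name), mask_path])
            written += 1
    return written
```

Calling write_pairs_csv("Segmentation", "pairs.csv") once produces the file; skipping unmatched images means a missing mask shows up as a lower count rather than as a FileNotFoundError during training.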

Upvotes: 0

Views: 357

Answers (1)

Hatem

Reputation: 521

I recommend writing a custom subclass of torch.utils.data.Dataset. The __init__ method builds the lists of image and mask paths once and stores them; __getitem__ then opens one image/mask pair per index. This is an example:

import os

import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class CustomData(Dataset):
    def __init__(self, data_dir='Segmentation', data_transform=None, split='train'):
        self.imgs = []
        self.labels = []
        self.transform = data_transform
        self.data_dir = data_dir
        # If the data is split into per-split folders (e.g. Segmentation/train/images):
        #self.imgs_dir = os.path.join(data_dir, split, 'images')
        #self.labels_dir = os.path.join(data_dir, split, 'masks')
        self.imgs_dir = os.path.join(data_dir, 'images')
        self.labels_dir = os.path.join(data_dir, 'masks')  # the mask folder in the question is "masks"
        # sorted() makes the order deterministic; os.listdir() order is arbitrary
        for img_name in sorted(os.listdir(self.imgs_dir)):
            img_path = os.path.join(self.imgs_dir, img_name)
            # color_left_trajectory_3000_00001.jpg -> color_segmentation_3000_00001.jpg
            label_name = "color_segmentation_" + "_".join(img_name.split('.')[0].split('_')[-2:]) + '.jpg'
            label_path = os.path.join(self.labels_dir, label_name)
            self.imgs.append(img_path)
            self.labels.append(label_path)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = Image.open(self.imgs[idx])
        label = Image.open(self.labels[idx])
        if self.transform is not None:
            img, label = self.transform(img, label)
        return img, label

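class ToTensor:
    """Joint transform: image -> float tensor in [0, 1], mask -> int64 tensor."""
    def __call__(self, image, target=None):
        image = F.to_tensor(image)
        if target is not None:
            target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target


if __name__ == '__main__':
    # Pass an instance, not the class: __getitem__ calls self.transform(img, label)
    data = CustomData(data_transform=ToTensor())
    dataloader = DataLoader(data, batch_size=10)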
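If you prefer the CSV route from the question, __init__ can read the path lists from that file instead of listing a 60k-file directory on every run. A minimal sketch, assuming the CSV has a header row with columns named image and mask (hypothetical names) and one image/mask path pair per row; read_pairs_csv is likewise a hypothetical helper:

```python
import csv


def read_pairs_csv(csv_path):
    """Read an (image, mask) CSV into two parallel lists of paths.

    Assumes a header row with columns named "image" and "mask".
    """
    imgs, labels = [], []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):  # keys come from the header row
            imgs.append(row["image"])
            labels.append(row["mask"])
    return imgs, labels
```

Inside CustomData.__init__, the os.listdir loop would then become self.imgs, self.labels = read_pairs_csv(csv_path), with csv_path as a hypothetical extra constructor argument; __len__ and __getitem__ stay unchanged.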

Upvotes: 2
