Rahul Bohare

Reputation: 821

(image, mask) pairs do not match one another in a semantic segmentation task

I am writing a simple custom Dataset class (to which I will add more features later) for a segmentation dataset, but the (image, mask) pairs returned by the __getitem__() method do not match: the returned mask belongs to a different image than the one returned with it. My directory structure is /home/bohare/data/images and /home/bohare/data/masks.

Following is the code I have:

import torch
from torch.utils.data.dataset import Dataset
from PIL import Image
import glob
import os
import matplotlib.pyplot as plt

class CustomDataset(Dataset):
    def __init__(self, folder_path):
        
        self.img_files = glob.glob(os.path.join(folder_path,'images','*.png'))
        self.mask_files = glob.glob(os.path.join(folder_path,'masks','*.png'))
    
    def __getitem__(self, index):
        
        image = Image.open(self.img_files[index])
        mask = Image.open(self.mask_files[index])
        
        return image, mask
    
    def __len__(self):
        return len(self.img_files)

data = CustomDataset(folder_path = '/home/bohare/data')
len(data)

This code correctly returns the total size of the dataset.

But when I call img, msk = data.__getitem__(n), where n is the index of any (image, mask) pair, and plot the image and the mask, they do not correspond to one another.
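A quick way to confirm this without plotting is to compare file names index by index. The sketch below assumes, hypothetically, that each mask has the same file name as its image (images/0001.png and masks/0001.png); find_mismatched_pairs is an illustrative helper, not part of PyTorch.

```python
import os


def find_mismatched_pairs(img_files, mask_files):
    """Return (index, image_name, mask_name) for every index whose image and
    mask file names differ.

    Hypothetical check: assumes a mask shares its image's file name.
    """
    if len(img_files) != len(mask_files):
        raise ValueError(f"{len(img_files)} images but {len(mask_files)} masks")
    mismatches = []
    for i, (img_path, mask_path) in enumerate(zip(img_files, mask_files)):
        img_name = os.path.basename(img_path)
        mask_name = os.path.basename(mask_path)
        if img_name != mask_name:
            mismatches.append((i, img_name, mask_name))
    return mismatches
```

Calling find_mismatched_pairs(data.img_files, data.mask_files) on the dataset above should return a non-empty list if the two glob results come back in different orders.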

How can I modify the code, or what can I add to it, to make sure each (image, mask) pair is returned correctly? Thanks for the help.

Upvotes: 1

Views: 855

Answers (1)

David

Reputation: 8298

glob.glob returns the matching paths in no guaranteed order, because internally it lists the directory with os.listdir, whose documentation says:

os.listdir(path) Return a list containing the names of the entries in the directory given by path. The list is in arbitrary order. It does not include the special entries '.' and '..' even if they are present in the directory.

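To solve it, you can sort both lists so that they are in the same order (this assumes each mask's file name sorts to the same position as its image's, e.g. identical names in both folders):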

self.img_files = sorted(glob.glob(os.path.join(folder_path,'images','*.png')))
self.mask_files = sorted(glob.glob(os.path.join(folder_path,'masks','*.png')))
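A more defensive variant, sketched below under the hypothetical assumption that folder_path/images/X.png has its mask at folder_path/masks/X.png, sorts only the images and derives each mask path from the image's file name, failing loudly if a mask is missing. PairedSegmentationDataset and the loader parameter are illustrative names. The class does not subclass torch.utils.data.Dataset only so the sketch runs without PyTorch installed; subclassing it as in the question works the same way.

```python
import glob
import os


def _open_with_pil(path):
    # Imported lazily so the sketch can be exercised without Pillow installed.
    from PIL import Image
    return Image.open(path)


class PairedSegmentationDataset:
    """Map-style dataset returning (image, mask) pairs matched by file name.

    Hypothetical sketch: assumes masks share their image's file name.
    """

    def __init__(self, folder_path, loader=None):
        image_dir = os.path.join(folder_path, "images")
        mask_dir = os.path.join(folder_path, "masks")
        # Sort so that index n refers to the same file on every run.
        self.img_files = sorted(glob.glob(os.path.join(image_dir, "*.png")))
        # Build each mask path from its image's file name instead of
        # globbing the masks folder independently.
        self.mask_files = [
            os.path.join(mask_dir, os.path.basename(p)) for p in self.img_files
        ]
        missing = [p for p in self.mask_files if not os.path.isfile(p)]
        if missing:
            raise FileNotFoundError(
                f"{len(missing)} mask(s) missing, e.g. {missing[0]}"
            )
        # `loader` is a hypothetical hook; by default files are opened with Pillow.
        self.loader = loader if loader is not None else _open_with_pil

    def __getitem__(self, index):
        return self.loader(self.img_files[index]), self.loader(self.mask_files[index])

    def __len__(self):
        return len(self.img_files)
```

Deriving the mask path matters when image and mask names differ. With images 1.png, 10.png, 2.png and masks 1_mask.png, 10_mask.png, 2_mask.png, sorting gives 1.png, 10.png, 2.png but 10_mask.png, 1_mask.png, 2_mask.png, so two sorted lists would still be misaligned; in that case the mask name would have to be built from the image's stem instead.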

Upvotes: 1
