Viktor hj

Reputation: 11

PyTorch: Loading a custom dataset

Hey, I've been searching around the web looking for help on importing a custom image set, but every good tutorial seems to just use MNIST, which is fine, except I don't know how to translate that code to a custom set. I've got a folder structure like this:

SET
|-->Training
-----|-->A
---------|-->8000 items
-----|-->B
---------|-->8000 items
|-->Validation
-----|-->A
---------|-->600 items
-----|-->B
---------|-->600 items

I want to train a GAN on the 8000 input images in Training set A, hopefully learning to mimic Training set B.

I've been having no luck understanding all the inheritance in the MNIST examples, or how to apply it to a custom set.

Upvotes: 1

Views: 902

Answers (2)

Sarah Belouchi

Reputation: 1

Here is how to customize each element of your dataloader.

Let's start with what the default torch DataLoader looks like:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
                                         batch_sampler=None, num_workers=0, collate_fn=None,
                                         pin_memory=False, drop_last=False, timeout=0,
                                         worker_init_fn=None, *, prefetch_factor=2,
                                         persistent_workers=False)

Let's first look at how to create your dataset. My example follows the logic in this tutorial, adapted to your dataset structure (with A and B image folders): https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

import copy
import random

import cv2
from torch.utils.data import Dataset

class GANDataset(Dataset):

    def __init__(self, csv_file, sampler, transform, mode='train'):
        """
        Arguments:
            csv_file: Path to the csv file. In your case, let's say it has two
                columns, "A" and "B", and, for your training set, 8000 rows.
                For example, item 0 of column A is the full path to your first
                training input image, and item 0 of column B is the full path
                to your first training ground-truth image.
            sampler: some sort of sampler, either a custom one or one of the
                torch samplers. I will show what a simple sampler looks like.
            transform: transforms to be applied to a sample. I will show a
                minimal transform composite later.
            mode: 'train' or 'inference'
        """
        self.dataset = read_data_from_csv(csv_file)
        self.dataset_indices = sampler.return_indices(self.dataset, mode)
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.dataset_indices)

    def __getitem__(self, idx):

        img_A, img_B = self.read_images(idx)
        # In training mode both input and target must be readable;
        # in inference mode only the input must be.
        to_move_on = False
        if self.mode == 'train':
            if img_A is not None and img_B is not None:
                to_move_on = True
        if self.mode == 'inference':
            if img_A is not None:
                to_move_on = True
        if to_move_on:
            data_sample = copy.deepcopy(self.dataset[idx])
            data_sample['img_B'] = img_B
            data_sample['img_A'] = img_A
            data_sample = self.transform(data_sample)
            return data_sample
        else:
            # If this sample fails to load for any reason, pick another idx
            # so that the training process doesn't stop.
            another_idx = random.randint(0, len(self.dataset_indices) - 1)
            return self[self.dataset_indices[another_idx]]

    def read_images(self, idx):
        img_A = None
        img_B = None
        try:
            if self.dataset[idx]['A'] is not None:
                img_A = cv2.imread(self.dataset[idx]['A'])
        except Exception:
            img_A = None

        try:
            if self.dataset[idx]['B'] is not None:
                img_B = cv2.imread(self.dataset[idx]['B'])
        except Exception:
            img_B = None

        return img_A, img_B
    

The function "read_data_from_csv" should be defined by you to build a list of dict-like samples from your csv file. It can be as simple as this:

import pandas as pd

def read_data_from_csv(csv_file):
    df = pd.read_csv(csv_file)
    dataset = []
    for index, row in df.iterrows():
        if not pd.isnull(row['A']):
            A_file = row['A']
        else:
            A_file = None

        if not pd.isnull(row['B']):
            B_file = row['B']
        else:
            B_file = None

        if A_file is not None or B_file is not None:
            new_data_item = {'A': A_file, 'B': B_file,
                             'img_A': None, 'img_B': None}
            dataset.append(new_data_item)
    return dataset

Let's look at a simple sampler with the most basic sampling possibility, which should be enough for a GAN model. Look here for more: https://github.com/pytorch/pytorch/blob/main/torch/utils/data/sampler.py

class NaiveSampler(object):
    def return_indices(self, dataset, mode):
        dataset_indices = list(range(len(dataset)))
        if mode == 'train':
            random.shuffle(dataset_indices)
        return dataset_indices

Let's look at basic transforms. This example first decides whether a data augmentation should happen; in my case I want the sample to be flipped with a certain probability (part of your training hyperparameters). This is followed by some other essential changes, like normalization/standardization and moving from a numpy array to a torch tensor. For more details, look at this tutorial: https://pytorch.org/vision/stable/transforms.html

transform = transforms.Compose([FlipAugmentation(hflip_probability, vflip_probability),
                                SampleGenerator(mean_A, std_A, mean_B, std_B)])

My augmentation class could look like this:

from functools import reduce

class FlipAugmentation(object):

    def __init__(self, hflip_probability, vflip_probability):

        augmentation_list = []
        self.sequence_augmentation = lambda x: x

        if hflip_probability > 0:
            augmentation_list.append(HorizontalFlip(hflip_probability))
        if vflip_probability > 0:
            augmentation_list.append(VerticalFlip(vflip_probability))

        if len(augmentation_list) > 0:
            # Chain the selected flips into a single callable.
            def compose(g, f):
                return lambda x: g(f(x))
            self.sequence_augmentation = reduce(compose, augmentation_list, lambda x: x)

    def __call__(self, sample):
        return self.sequence_augmentation(sample)
        
import numpy as np

class HorizontalFlip(object):

    def __init__(self, hflip_probability):
        self.hflip_probability = hflip_probability

    def __call__(self, sample):
        if np.random.rand() < self.hflip_probability:

            returned_sample = copy.deepcopy(sample)
            # Flip along the width axis (images are H x W x C).
            returned_sample['img_A'] = returned_sample['img_A'][:, ::-1, :]
            if returned_sample['img_B'] is not None:  # img_B can be None in inference mode
                returned_sample['img_B'] = returned_sample['img_B'][:, ::-1, :]
            return returned_sample
        else:
            return sample
            
class VerticalFlip(object):

    def __init__(self, vflip_probability):
        self.vflip_probability = vflip_probability

    def __call__(self, sample):
        if np.random.rand() < self.vflip_probability:

            returned_sample = copy.deepcopy(sample)
            # Flip along the height axis (images are H x W x C).
            returned_sample['img_A'] = returned_sample['img_A'][::-1, :, :]
            if returned_sample['img_B'] is not None:  # img_B can be None in inference mode
                returned_sample['img_B'] = returned_sample['img_B'][::-1, :, :]
            return returned_sample
        else:
            return sample
            
            

And my sample generator is as simple as this; it standardizes the input and target images:

import torch

class SampleGenerator(object):

    def __init__(self, mean_A, std_A, mean_B, std_B):
        self.normalize_means_stds = [mean_A, std_A, mean_B, std_B]

    def __call__(self, sample):
        img_A, img_B = sample['img_A'], sample['img_B']
        sample_called = dict(sample)
        # Scale to [0, 1], then standardize with the given mean/std.
        img_A = (img_A.astype(np.float32) / 255.0 - self.normalize_means_stds[0]) / self.normalize_means_stds[1]
        sample_called['img_A'] = torch.from_numpy(img_A)
        if img_B is not None:
            img_B = (img_B.astype(np.float32) / 255.0 - self.normalize_means_stds[2]) / self.normalize_means_stds[3]
            sample_called['img_B'] = torch.from_numpy(img_B)
        return sample_called
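
One caveat, depending on your model: cv2.imread returns images as height x width x channel (in BGR order), and torch.from_numpy keeps that layout, while torch conv layers expect channel-first tensors. A hedged sketch of the extra step you might add before the torch.from_numpy calls above:

# Hypothetical extra step: HWC -> CHW, since most torch conv layers expect
# channel-first input; np.ascontiguousarray avoids stride issues after the
# transpose. BGR -> RGB only matters if you rely on pretrained RGB weights.
img_A = np.ascontiguousarray(img_A.transpose(2, 0, 1))
sample_called['img_A'] = torch.from_numpy(img_A)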

You finally need a batch sampler too, which can again be as simple as this (for more details look here: https://medium.com/@haleema.ramzan/how-to-build-a-custom-batch-sampler-in-pytorch-ce04161583ee):

from torch.utils.data.sampler import Sampler

class CustomBatchSampler(Sampler):
    def __init__(self, dataset, batch_size):
        self.batch_size = batch_size
        self.data = dataset
        index_list = self.data.dataset_indices
        number_batches = self.__len__()
        batch_list = []
        # Build fixed batches up front; the modulo wraps around, so the
        # last batch is padded with indices from the start of the list.
        for start in range(0, number_batches * self.batch_size, self.batch_size):
            batch = []
            for b in range(start, start + self.batch_size):
                batch.append(index_list[b % len(index_list)])
            batch_list.append(batch)
        self.batch_list = batch_list

    def __iter__(self):
        for batch in self.batch_list:
            yield batch

    def __len__(self):
        return len(self.data) // self.batch_size
    

Now it has all come together and you can create your dataloader:

from torchvision import transforms
from torch.utils.data import DataLoader

csv_file = 'path/to/your/file.csv'
sampler = NaiveSampler()
transform = transforms.Compose([FlipAugmentation(0.3, 0.3),
                                SampleGenerator(0, 1.0, 0, 1.0)])
train_dataset = GANDataset(csv_file, sampler, transform, mode='train')
batch_sampler = CustomBatchSampler(train_dataset, batch_size=16)
train_dataloader = DataLoader(train_dataset, num_workers=0,
                              batch_sampler=batch_sampler,
                              pin_memory=torch.cuda.is_available())
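
From here, a training loop just iterates over the dataloader; with the default collate_fn, each batch comes out as a dict holding stacked 'img_A' and 'img_B' tensors (assuming all your images share the same size). A minimal consumption sketch, where generator, device and num_epochs are hypothetical placeholders for your own GAN code:

# Sketch only: `generator`, `device` and `num_epochs` are assumed to be
# defined by your GAN setup; they are not part of this answer.
for epoch in range(num_epochs):
    for batch in train_dataloader:
        real_A = batch['img_A'].to(device)  # input images
        real_B = batch['img_B'].to(device)  # images to mimic
        fake_B = generator(real_A)
        # ... discriminator and generator updates go here ...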

Upvotes: 0

Domenico

Reputation: 176

You need to read your image files with a class that derives from the torch.utils.data.Dataset class in order to have your custom dataset. You can follow this part of the documentation for a basic example of how to populate a custom Dataset. So, after you define

from torch.utils.data import Dataset
class CustomImageDataset(Dataset):

with the three mandatory methods (look at the documentation above)

def __init__(self, img_dir, ...):
def __len__(self):
def __getitem__(self, idx):

you can create an instance of the class and test your code by verifying that the number of files is the intended one and that the method can fetch images, for example with the following lines:

trainset = CustomImageDataset(train_image_dir)
print('N of loaded images: {}'.format(len(trainset)))
first_image_A, first_image_B = trainset[0]

Most likely, you want __init__ to read all the file paths (or even all the images) into RAM, so it is within __init__ that you define the logic for exploring the paths and loading the pictures. If I understood correctly what you want, __getitem__ could be defined so that it returns two elements: the first being an image in the A folder, and the second being the related image of the B folder, as sketched below.
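
For instance, here is a minimal sketch of the three methods for your folder layout; my assumption (not stated in the question) is that A and B contain files that correspond by sorted filename:

import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        # img_dir is e.g. 'SET/Training'; A and B are assumed to pair up
        # by sorted filename - adapt this rule to your actual naming.
        self.dir_A = os.path.join(img_dir, 'A')
        self.dir_B = os.path.join(img_dir, 'B')
        self.files = sorted(os.listdir(self.dir_A))
        self.transform = transform or transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        name = self.files[idx]
        img_A = Image.open(os.path.join(self.dir_A, name)).convert('RGB')
        img_B = Image.open(os.path.join(self.dir_B, name)).convert('RGB')
        return self.transform(img_A), self.transform(img_B)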

Afterwards, you only need to instantiate the validation dataset, without defining a new class:

valset = CustomImageDataset(valid_image_dir)

From this point on, you have the logic for reading the data, and you can let PyTorch handle the batching of images through its own implementation of the DataLoader, which you do not have to derive like we did before, but just instantiate as train_dataloader and valid_dataloader:

from torch.utils.data import DataLoader
train_dataloader = DataLoader(trainset, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(valset, batch_size=64, shuffle=True)
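
To verify everything is wired up, you can pull a single batch; this assumes all images share the same size, so that the default collation can stack them:

# Fetch one batch; both tensors should have a shape like torch.Size([64, 3, H, W]).
images_A, images_B = next(iter(train_dataloader))
print(images_A.shape, images_B.shape)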

Upvotes: 1
