Reputation: 11
Hey, I've been searching around the web looking for help on importing a custom image set, but every good tutorial seems to just use MNIST, which is fine, but I don't know how to translate that code to a custom set.
I've got a folder structure like this:
SET:
|--> Training
|    |--> A
|    |    |--> 8000 items
|    |--> B
|         |--> 8000 items
|--> Validation
     |--> A
     |    |--> 600 items
     |--> B
          |--> 600 items
I want to train a GAN on the set of 8000 input images in Training set A to hopefully learn to mimic Training set B
I've had no luck understanding all the class inheritance and self references in the MNIST examples, or how to apply that to a custom set.
Upvotes: 1
Views: 902
Reputation: 1
Here is how to customize the pieces of your dataloader.
Let's start with what the default torch DataLoader signature looks like:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
Let's first look at how to create your "dataset". My example follows the logic in this tutorial, adapted to your dataset structure (with A and B image folders): https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
import copy
import random

import cv2
from torch.utils.data import Dataset


class GANDataset(Dataset):
    def __init__(self, csv_file, sampler, transform, mode='train'):
        """
        Arguments:
            csv_file: Path to the csv file. In your case, let's say it has two columns,
                "A" and "B", and (for your training set) 8000 rows. For example, item 0
                in column A is the full path to your first training input image, and
                item 0 in column B is the full path to the corresponding target image.
            sampler: you need some sort of sampler, either a custom one or one of the
                torch samplers. I will show what a simple sampler looks like.
            transform: transforms to be applied on a sample. I will show a minimal
                transform composite later.
            mode: 'train' or 'inference'
        """
        self.dataset = read_data_from_csv(csv_file)
        self.dataset_indices = sampler.return_indices(self.dataset, mode)
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.dataset_indices)

    def __getitem__(self, idx):
        img_A, img_B = self.read_images(idx)
        # In training mode both input and target must be readable;
        # in inference mode only the input must be.
        to_move_on = False
        if self.mode == 'train':
            if img_A is not None and img_B is not None:
                to_move_on = True
        if self.mode == 'inference':
            if img_A is not None:
                to_move_on = True
        if to_move_on:
            data_sample = copy.deepcopy(self.dataset[idx])
            data_sample['img_B'] = img_B
            data_sample['img_A'] = img_A
            data_sample = self.transform(data_sample)
            return data_sample
        else:
            # If this sample cannot be read for any reason, pick another index
            # so that the training process doesn't stop.
            another_idx = random.randint(0, len(self.dataset_indices) - 1)
            return self[self.dataset_indices[another_idx]]

    def read_images(self, idx):
        img_A = None
        img_B = None
        try:
            if self.dataset[idx]['A'] is not None:
                img_A = cv2.imread(self.dataset[idx]['A'])
        except Exception:
            img_A = None
        try:
            if self.dataset[idx]['B'] is not None:
                img_B = cv2.imread(self.dataset[idx]['B'])
        except Exception:
            img_B = None
        return img_A, img_B
The function "read_data_from_csv" is up to you to define; it builds a list of dict entries from your csv file. It can be as simple as this:
import pandas as pd


def read_data_from_csv(csv_file):
    df = pd.read_csv(csv_file)
    dataset = []
    for index, row in df.iterrows():
        if not pd.isnull(row['A']) and row['A'] is not None:
            A_file = row['A']
        else:
            A_file = None
        if not pd.isnull(row['B']) and row['B'] is not None:
            B_file = row['B']
        else:
            B_file = None
        # Keep the entry as long as at least one of the two paths is valid.
        if A_file is not None or B_file is not None:
            new_data_item = {'A': A_file, 'B': B_file,
                             'img_A': None, 'img_B': None}
            dataset.append(new_data_item)
    return dataset
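Since your data lives in Training/A and Training/B folders rather than a csv, a minimal sketch for generating such a csv could look like the following. The build_csv helper is hypothetical, and pairing by sorted filename order is an assumption; adapt the pairing logic to however an A image corresponds to its B image:

import os
import pandas as pd

def build_csv(root, csv_path):
    # Pair images from the A and B sub-folders by sorted filename order (assumption).
    a_dir = os.path.join(root, 'A')
    b_dir = os.path.join(root, 'B')
    a_files = sorted(os.listdir(a_dir))
    b_files = sorted(os.listdir(b_dir))
    rows = [{'A': os.path.join(a_dir, a), 'B': os.path.join(b_dir, b)}
            for a, b in zip(a_files, b_files)]
    pd.DataFrame(rows).to_csv(csv_path, index=False)

build_csv('SET/Training', 'train.csv')
build_csv('SET/Validation', 'val.csv')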
Let's look at a simple sampler with the most basic sampling possible, which should be enough for a GAN model. Look here for more: https://github.com/pytorch/pytorch/blob/main/torch/utils/data/sampler.py
class NaiveSampler(object):
    def return_indices(self, dataset, mode):
        dataset_indices = [idx for idx in range(len(dataset))]
        if mode == 'train':
            random.shuffle(dataset_indices)
        return dataset_indices
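Used on its own, this simply returns the list of dataset indices, shuffled in 'train' mode. For example (with a dummy list standing in for the dataset):

sampler = NaiveSampler()
indices = sampler.return_indices(list(range(5)), 'train')  # e.g. [3, 0, 4, 1, 2]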
Let's look at basic transforms. This example first decides whether a data augmentation should happen or not; in my case I want the sample to be flipped with a certain probability (part of your training hyperparameters), followed by some other essential steps (normalization/standardization and conversion from numpy array to torch tensor). For more details look at this tutorial: https://pytorch.org/vision/stable/transforms.html
transform = transforms.Compose([FlipAugmentation(hflip_probability, vflip_probability),
                                SampleGenerator(mean_A, std_A, mean_B, std_B)])
My augmentation class could look like this:
from functools import reduce

import numpy as np


class FlipAugmentation(object):
    def __init__(self, hflip_probability, vflip_probability):
        augmentation_list = []
        self.sequence_augmentation = lambda x: x  # identity if no augmentation is enabled
        if hflip_probability > 0:
            augmentation_list.append(HorizontalFlip(hflip_probability))
        if vflip_probability > 0:
            augmentation_list.append(VerticalFlip(vflip_probability))
        if len(augmentation_list) > 0:
            # Chain the enabled augmentations into a single callable.
            def compose(g, f):
                return lambda x: g(f(x))
            self.sequence_augmentation = reduce(compose, augmentation_list, lambda x: x)

    def __call__(self, sample):
        return self.sequence_augmentation(sample)
class HorizontalFlip(object):
    def __init__(self, hflip_probability):
        self.hflip_probability = hflip_probability

    def __call__(self, sample):
        if np.random.rand() < self.hflip_probability:
            returned_sample = copy.deepcopy(sample)
            # Flip along the width axis (images are in HWC layout from cv2).
            img_A = returned_sample['img_A']
            img_A = img_A[:, ::-1, :]
            returned_sample['img_A'] = img_A
            img_B = returned_sample['img_B']
            img_B = img_B[:, ::-1, :]
            returned_sample['img_B'] = img_B
            return returned_sample
        else:
            return sample
class VerticalFlip(object):
    def __init__(self, vflip_probability):
        self.vflip_probability = vflip_probability

    def __call__(self, sample):
        if np.random.rand() < self.vflip_probability:
            returned_sample = copy.deepcopy(sample)
            # Flip along the height axis.
            img_A = returned_sample['img_A']
            img_A = img_A[::-1, :, :]
            returned_sample['img_A'] = img_A
            img_B = returned_sample['img_B']
            img_B = img_B[::-1, :, :]
            returned_sample['img_B'] = img_B
            return returned_sample
        else:
            return sample
And my sample generator is as simple as this, which standardizes input and target images:
import torch


class SampleGenerator(object):
    def __init__(self, mean_A, std_A, mean_B, std_B):
        self.normalize_means_stds = [mean_A, std_A, mean_B, std_B]

    def __call__(self, sample):
        img_A, img_B = sample['img_A'], sample['img_B']
        sample_called = dict(sample)
        # Scale to [0, 1], standardize with the per-domain mean/std, and convert to a torch tensor.
        img_A = ((img_A.astype(np.float32) / 255.0) - self.normalize_means_stds[0]) / self.normalize_means_stds[1]
        sample_called['img_A'] = torch.from_numpy(img_A)
        if img_B is not None:
            img_B = ((img_B.astype(np.float32) / 255.0) - self.normalize_means_stds[2]) / self.normalize_means_stds[3]
            sample_called['img_B'] = torch.from_numpy(img_B)
        return sample_called
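One note not covered by the code above: cv2.imread returns images in HWC layout with BGR channel order, while torch conv layers expect CHW, so depending on your GAN you may want to add a transpose (and possibly a BGR-to-RGB swap) either inside SampleGenerator or before feeding the batch to the model. A sketch of that extra step could be:

# Hypothetical extra step inside SampleGenerator.__call__, after torch.from_numpy:
# HWC (cv2 layout) -> CHW (torch conv layout)
sample_called['img_A'] = sample_called['img_A'].permute(2, 0, 1)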
Finally, you need a batch sampler too, which can again be as simple as this (for more details look here: https://medium.com/@haleema.ramzan/how-to-build-a-custom-batch-sampler-in-pytorch-ce04161583ee#:~:text=from%20torch.utils.data.sampler%20import%20Sampler-,class%20CustomBatchSampler(Sampler)%3A,-r%22%22%22Yield%20a%20mini%2Dbatch%20of%20indices.%20The):
from torch.utils.data.sampler import Sampler


class CustomBatchSampler(Sampler):
    def __init__(self, dataset, batch_size):
        self.batch_size = batch_size
        self.data = dataset
        # Pre-compute the list of batches, each a list of dataset indices.
        index_list = self.data.dataset_indices
        number_batches = self.__len__()
        batch_list = []
        for start in range(0, number_batches * self.batch_size, self.batch_size):
            batch = []
            for b in range(start, start + self.batch_size):
                batch.append(index_list[b % len(index_list)])
            batch_list.append(batch)
        self.batch_list = batch_list

    def __iter__(self):
        for batch in self.batch_list:
            yield batch

    def __len__(self):
        return len(self.data) // self.batch_size
Now it all comes together and you can create your dataloader:
from torch.utils.data import DataLoader
from torchvision import transforms

csv_file = 'path/to/your/file.csv'
sampler = NaiveSampler()
transform = transforms.Compose([FlipAugmentation(0.3, 0.3),
                                SampleGenerator(0, 1.0, 0, 1.0)])
train_dataset = GANDataset(csv_file, sampler, transform, mode='train')
batch_sampler = CustomBatchSampler(train_dataset, batch_size=16)
train_dataloader = DataLoader(train_dataset, num_workers=0,
                              batch_sampler=batch_sampler,
                              pin_memory=torch.cuda.is_available())
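From here a training loop just iterates over the dataloader; each batch is a dict with 'img_A' and 'img_B' tensors. A very schematic sketch (num_epochs and the actual generator/discriminator updates are up to you and not defined here):

for epoch in range(num_epochs):
    for batch in train_dataloader:
        real_A = batch['img_A']  # batch of input images from folder A
        real_B = batch['img_B']  # batch of target images from folder B
        # ... run your generator/discriminator updates here ...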
Upvotes: 0
Reputation: 176
You need to read your image files with a class that derives from the torch.utils.data.Dataset class, in order to have your custom dataset.
You can follow this part of the documentation to have a basic example of how to populate a custom Dataset.
So, after you define
from torch.utils.data import Dataset
class CustomImageDataset(Dataset):
with the three mandatory methods (look at the documentation above)
    def __init__(self, img_dir, ...):
    def __len__(self):
    def __getitem__(self, idx):
you can create an instance of the class and test your code, for example by verifying that the number of loaded files is the intended one and that the method can fetch images, with the following lines:
trainset = CustomImageDataset(train_image_dir)
print('N of loaded images: {}'.format(len(trainset)))
first_image_A, first_image_B = trainset[0]
Most likely, you want __init__ to read all the file paths (or even the images themselves) into RAM, so it is within __init__ that you define the logic for exploring the paths and loading the pictures. If I understood correctly what you want, __getitem__ could be defined so that it returns two elements: the first being an image from the A folder, and the second being the related image from the B folder (see the sketch below).
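A minimal sketch of such a class could look like the following; pairing A and B images by sorted filename order is an assumption, so adapt it to however your A and B images actually correspond:

import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=transforms.ToTensor()):
        # Collect paired file paths from the A and B sub-folders (paired by sorted name).
        self.a_paths = sorted(os.path.join(img_dir, 'A', f) for f in os.listdir(os.path.join(img_dir, 'A')))
        self.b_paths = sorted(os.path.join(img_dir, 'B', f) for f in os.listdir(os.path.join(img_dir, 'B')))
        self.transform = transform

    def __len__(self):
        return len(self.a_paths)

    def __getitem__(self, idx):
        # Return the A image and its related B image as tensors.
        img_a = self.transform(Image.open(self.a_paths[idx]).convert('RGB'))
        img_b = self.transform(Image.open(self.b_paths[idx]).convert('RGB'))
        return img_a, img_b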
Afterwards, you would only need to instantiate the validation dataset, without having to define a new class:
valset = CustomImageDataset(valid_image_dir)
From this point, you have the logic for reading the data. Afterwards, you can let pytorch handle the batching of images through its own implementation of DataLoader, which you do not have to subclass like we did before; you just instantiate train_dataloader and valid_dataloader:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(trainset, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(valset, batch_size=64, shuffle=True)
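With __getitem__ returning the (A, B) pair, iterating the dataloader then yields batched tensors for both folders. A rough sketch (the GAN update itself is left out):

for images_a, images_b in train_dataloader:
    # images_a and images_b each have shape (batch_size, C, H, W)
    ...  # feed images_a to the generator and compare its output against images_b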
Upvotes: 1