Reputation: 26108
How to use different data augmentation (transforms) for different Subsets in PyTorch?
For instance:
train, test = torch.utils.data.random_split(dataset, [80000, 2000])
Both train and test will have the same transforms as dataset. How can I use custom transforms for these subsets?
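The reason is that both Subsets returned by random_split keep a reference to the same underlying dataset object, so they inherit its transform. A minimal sketch of the situation (the ImageFolder path and split sizes are placeholders):

import torch
from torchvision import datasets, transforms

# Placeholder dataset; the path and split sizes are illustrative only.
dataset = datasets.ImageFolder('data/train', transform=transforms.ToTensor())
train, test = torch.utils.data.random_split(dataset, [80000, 2000])

# Both Subsets wrap the very same dataset object, hence the shared transform.
print(train.dataset is test.dataset)  # True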
Upvotes: 13
Views: 9463
Reputation: 508
You can use a custom collate_fn for every subset.
I've used it in object detection with a custom dataset, where every sample is a dictionary that contains the image and the metadata:
import torch
import albumentations as A  # the transforms below are Albumentations pipelines

def collate_fn_transform(transform):
    def collate_fn(batch):
        # Apply the per-subset transform to every sample before collating.
        for sample in batch:
            transformed = transform(image=sample['image'], bboxes=sample['boxes'],
                                    keypoints=sample['keypoints'], labels=sample['labels'])
            sample['image'] = transformed['image']
            sample['boxes'] = torch.tensor(transformed['bboxes'], dtype=torch.float32)
            sample['keypoints'] = torch.tensor(transformed['keypoints'], dtype=torch.float32).unsqueeze(0)
        return batch
    return collate_fn
indices = torch.randperm(len(dataset))

train_set = torch.utils.data.Subset(dataset, indices=indices[:train_size])
train_transform = A.Compose([...])

val_set = torch.utils.data.Subset(dataset, indices=indices[train_size:])
val_transform = A.Compose([...])

loaders = {
    'train': torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                         collate_fn=collate_fn_transform(train_transform),
                                         num_workers=4, pin_memory=True),
    'val': torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False,
                                       collate_fn=collate_fn_transform(val_transform))
}
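For a plain classification dataset that yields (image, label) pairs, the same trick looks roughly like this (make_collate_fn, train_set and train_transform are illustrative names; the transform is assumed to turn each image into a tensor):

import torch

def make_collate_fn(transform):
    def collate_fn(batch):
        # Apply the per-subset transform to each image, then batch manually.
        images, labels = zip(*batch)
        images = torch.stack([transform(img) for img in images])
        return images, torch.tensor(labels)
    return collate_fn

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=32, shuffle=True,
    collate_fn=make_collate_fn(train_transform))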
Upvotes: 0
Reputation: 9806
This is what I use (taken from here):
import torch
from torch.utils.data import Dataset, TensorDataset, random_split
from torchvision import transforms


class DatasetFromSubset(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)
Here's an example:
init_dataset = TensorDataset(
    torch.randn(100, 3, 24, 24),
    torch.randint(0, 10, (100,))
)
lengths = [int(len(init_dataset) * 0.8), int(len(init_dataset) * 0.2)]
train_subset, test_subset = random_split(init_dataset, lengths)

train_dataset = DatasetFromSubset(
    train_subset, transform=transforms.Normalize((0., 0., 0.), (0.5, 0.5, 0.5))
)
test_dataset = DatasetFromSubset(
    test_subset, transform=transforms.Normalize((0., 0., 0.), (0.5, 0.5, 0.5))
)
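To give the two splits genuinely different pipelines, just pass different transforms, e.g. (the augmentation is chosen purely for illustration and assumes a torchvision version whose transforms accept tensors):

train_dataset = DatasetFromSubset(
    train_subset,
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Normalize((0., 0., 0.), (0.5, 0.5, 0.5)),
    ])
)
test_dataset = DatasetFromSubset(
    test_subset, transform=transforms.Normalize((0., 0., 0.), (0.5, 0.5, 0.5))
)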
Upvotes: 6
Reputation: 629
I've given up and written my own Subset (almost identical to the PyTorch one). I keep the transform in the Subset (not in the parent dataset).
class Subset(Dataset):
    r"""
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
        transform (callable): Transform applied to each sample of the subset
    """
    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __getitem__(self, idx):
        im, labels = self.dataset[self.indices[idx]]
        return self.transform(im), labels

    def __len__(self):
        return len(self.indices)
You'll also have to write your own split function.
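For completeness, a rough sketch of such a split function, modelled on torch.utils.data.random_split and returning the custom Subset above (the function name and the dataset/transform variables are illustrative):

import torch

def random_split_with_transforms(dataset, lengths, transforms):
    # Shuffle all indices once, then carve out one Subset per requested length,
    # each carrying its own transform.
    indices = torch.randperm(len(dataset)).tolist()
    subsets, start = [], 0
    for length, transform in zip(lengths, transforms):
        subsets.append(Subset(dataset, indices[start:start + length], transform))
        start += length
    return subsets

train_set, val_set = random_split_with_transforms(
    full_dataset, [train_size, val_size], [train_transform, val_transform])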
Upvotes: 4
Reputation: 26108
My current solution is not very elegant, but it works:
from copy import copy

train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Give the train split its own (shallow) copy of the underlying dataset,
# so the two splits can carry different transforms.
train_dataset.dataset = copy(full_dataset)

test_dataset.dataset.transform = transforms.Compose([
    transforms.Resize(img_resolution),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset.dataset.transform = transforms.Compose([
    transforms.RandomResizedCrop(img_resolution[0]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
Basically, I'm defining a new dataset (which is a copy of the original dataset) for one of the splits, and then I define a custom transform for each split.
Note: train_dataset.dataset.transform works since I'm using an ImageFolder dataset, which uses the .transform attribute to perform the transforms.
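The copy is what makes the two transform assignments independent: train_dataset.dataset and test_dataset.dataset are now different objects, so setting .transform on one no longer touches the other. A quick illustrative check:

# After the copy above, the two splits wrap different dataset objects:
print(train_dataset.dataset is test_dataset.dataset)                      # False
print(train_dataset.dataset.transform is test_dataset.dataset.transform)  # False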
If anybody knows a better solution, please share with us!
Upvotes: 13