Cici

Reputation: 23

How to split a dataset into a custom training set and a custom validation set with pytorch?

I'm using a non-torchvision dataset that I load with ImageFolder. I'm trying to split it into a 20% validation set and an 80% training set. The only method I can find in the PyTorch library for splitting a dataset is random_split, but it produces a different split every time. Is there a way in PyTorch to split the dataset into specific, fixed subsets?

This is my code for loading the dataset and splitting it randomly.

import torch
from torchvision import datasets, transforms

transformations = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

TrafficSignSet = datasets.ImageFolder(root='./train/', transform=transformations)

####### split data
train_size = int(0.8 * len(TrafficSignSet))
test_size = len(TrafficSignSet) - train_size
train_dataset_split, test_dataset_split = torch.utils.data.random_split(TrafficSignSet, [train_size, test_size])

####### put into a DataLoader
train_dataset = torch.utils.data.DataLoader(train_dataset_split, batch_size=32, shuffle=True)
test_dataset = torch.utils.data.DataLoader(test_dataset_split, batch_size=32, shuffle=True)

Upvotes: 1

Views: 1622

Answers (1)

Shai

Reputation: 114936

If you look "under the hood" of random_split you'll see it uses torch.utils.data.Subset to do the actual splitting. You can do so yourself with fixed indices:

import random

indices = list(range(len(TrafficSignSet)))
random.seed(310)  # fix the seed so the shuffle is the same every time
random.shuffle(indices)
# train_size is the 80% count computed in the question's code
train_dataset_split = torch.utils.data.Subset(TrafficSignSet, indices[:train_size])
val_dataset_split = torch.utils.data.Subset(TrafficSignSet, indices[train_size:])
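
The same idea can be wrapped in a small helper so the split does not depend on, or disturb, the global random state. This is only a sketch: the function name fixed_split_indices and its parameters are hypothetical, not part of PyTorch, and the 100-sample size is a stand-in for len(TrafficSignSet).

```python
import random


def fixed_split_indices(n_samples, train_fraction=0.8, seed=310):
    """Return (train_indices, val_indices) that are identical on every run.

    Hypothetical helper for illustration; not a PyTorch function.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError("train_fraction must be between 0 and 1")
    indices = list(range(n_samples))
    # A private Random instance leaves the global random state untouched.
    random.Random(seed).shuffle(indices)
    train_size = int(train_fraction * n_samples)
    return indices[:train_size], indices[train_size:]


# 100 stands in for len(TrafficSignSet) here.
train_idx, val_idx = fixed_split_indices(100)

# With PyTorch available, the index lists plug into Subset as above:
# train_dataset_split = torch.utils.data.Subset(TrafficSignSet, train_idx)
# val_dataset_split = torch.utils.data.Subset(TrafficSignSet, val_idx)
```

If keeping random_split is preferable, it also accepts a generator argument, so passing generator=torch.Generator().manual_seed(310) should make its split repeatable as well.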

Upvotes: 2
