I'm using a non-torchvision dataset, which I load with ImageFolder. I want to split it into an 80% training set and a 20% validation set. The only PyTorch utility I can find for splitting a dataset is random_split, but it produces a different split on every run. Is there a way in PyTorch to split the dataset in a fixed, reproducible way?
This is my code for loading the dataset and splitting it randomly:
import torch
from torchvision import datasets, transforms

transformations = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
TrafficSignSet = datasets.ImageFolder(root='./train/', transform=transformations)
####### split data (80% train / 20% validation)
train_size = int(0.8 * len(TrafficSignSet))
test_size = len(TrafficSignSet) - train_size
train_dataset_split, test_dataset_split = torch.utils.data.random_split(TrafficSignSet, [train_size, test_size])
####### put into a DataLoader
train_dataset = torch.utils.data.DataLoader(train_dataset_split, batch_size=32, shuffle=True)
test_dataset = torch.utils.data.DataLoader(test_dataset_split, batch_size=32, shuffle=True)
If you look "under the hood" of random_split, you'll see it uses torch.utils.data.Subset to do the actual splitting. You can do the same yourself with a fixed list of indices:
import random
import torch

indices = list(range(len(TrafficSignSet)))
random.seed(310)  # fix the seed so the shuffle is the same every time
random.shuffle(indices)
# train_size = int(0.8 * len(TrafficSignSet)), as computed in the question
train_dataset_split = torch.utils.data.Subset(TrafficSignSet, indices[:train_size])
val_dataset_split = torch.utils.data.Subset(TrafficSignSet, indices[train_size:])
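If you want the index logic in a reusable helper, here is a minimal, framework-free sketch of the same idea. The name fixed_split_indices is hypothetical (it is not a PyTorch function), and the 0.8 default and seed 310 simply mirror the code above. One design difference: it shuffles with its own random.Random(seed) instance instead of calling random.seed, so the split is still reproducible but the global random state used by other code is left untouched.

```python
import random
from typing import List, Tuple


def fixed_split_indices(n: int, train_fraction: float = 0.8,
                        seed: int = 310) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: return (train_indices, val_indices) for a dataset of length n.

    The same n, train_fraction and seed always produce the same split.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError("train_fraction must be between 0 and 1")
    indices = list(range(n))
    rng = random.Random(seed)  # private generator: does not touch the global random state
    rng.shuffle(indices)
    train_size = int(train_fraction * n)  # same rounding as int(0.8 * len(...)) above
    return indices[:train_size], indices[train_size:]
```

Usage with the dataset from the question would be train_idx, val_idx = fixed_split_indices(len(TrafficSignSet)), then torch.utils.data.Subset(TrafficSignSet, train_idx) and torch.utils.data.Subset(TrafficSignSet, val_idx). Alternatively, PyTorch versions that support it let random_split take a seeded generator directly, e.g. random_split(TrafficSignSet, [train_size, test_size], generator=torch.Generator().manual_seed(42)), which also yields the same split on every run.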