Reputation: 1932
I am trying to download the data from the Oxford Flowers 102 dataset and split it into training, validation and test sets using the tfds API. Here is my code:
import tensorflow_datasets as tfds

# Split numbers (relative weights for train / validation / test)
train_split = 60
test_val_split = 20
splits = tfds.Split.ALL.subsplit([train_split, test_val_split, test_val_split])
# TODO: Create a training set, a validation set and a test set.
(training_set, validation_set, test_set), dataset_info = tfds.load('oxford_flowers102', split=splits, as_supervised=True, with_info=True)
The trouble is that when I print out dataset_info, I get the following numbers for the test, train and validation sets:
    total_num_examples=8189,
    splits={
        'test': 6149,
        'train': 1020,
        'validation': 1020,
    },
Question: How do I get the data to split into 6149 in the training set and 1020 in the test and validation sets?
Upvotes: 2
Views: 1616
Reputation: 1250
It seems to be a bug in the dataset itself, especially because the total size of the dataset is 8189 and 6149 is not 60% of the total but about 75%, so no splitting was performed at all. The splits were probably labeled the wrong way. Also, even when I try to load the dataset in the different ways described here ( https://github.com/tensorflow/datasets/blob/master/docs/splits.md ), I get the same wrong split.
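The percentages above can be checked directly from the split sizes that dataset_info reports (a quick sketch; the numbers are the ones printed in the question):

```python
# Split sizes as printed from dataset_info in the question
sizes = {"train": 1020, "validation": 1020, "test": 6149}
total = sum(sizes.values())  # 8189 examples in all three predefined splits

# Share of the whole dataset held by each predefined split, in percent
shares = {name: round(100 * n / total, 1) for name, n in sizes.items()}
for name, pct in shares.items():
    print(f"{name}: {sizes[name]} examples, {pct}% of {total}")
```

So the split labeled 'test' holds roughly three quarters of the data, while 'train' and 'validation' hold about 12.5% each.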
An easy workaround is to swap them: pass the test set to the model as the training set and vice versa, but then you will not get the percentages you want. Otherwise, you can load the entire dataset (train+test+validation) and split it yourself.
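If your TFDS version supports the string slicing syntax from the splits guide linked above (percent slices such as 'train[:60%]', joined with '+'), another way to do the split yourself is to take the same percentage range of every predefined split and concatenate the pieces. A minimal sketch; percent_splits is a hypothetical helper that only builds the split strings:

```python
def percent_splits(bounds, splits=("train", "validation", "test")):
    """Build one TFDS split string per (start, end) percentage range.

    Each string takes the same percentage slice of every predefined split
    and joins the slices with '+', e.g. 'train[:60%]+validation[:60%]+test[:60%]'.
    """
    result = []
    for start, end in bounds:
        lo = f"{start}%" if start > 0 else ""   # omit a 0% lower bound
        hi = f"{end}%" if end < 100 else ""     # omit a 100% upper bound
        result.append("+".join(f"{name}[{lo}:{hi}]" for name in splits))
    return result


# 60% / 20% / 20% of each predefined split
split_strings = percent_splits([(0, 60), (60, 80), (80, 100)])
for s in split_strings:
    print(s)
```

The resulting list could then be passed as split=split_strings to tfds.load('oxford_flowers102', ..., as_supervised=True), which returns one dataset per string. Because each percentage slice is rounded per split, the totals should be close to, but not necessarily exactly, 60/20/20 of 8189.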
import tensorflow_datasets as tfds

# Load all three predefined splits concatenated into one dataset
df_all, summary = tfds.load('oxford_flowers102', split='train+test+validation', with_info=True)
# check if the dataset loaded truly contains everything
df_all_length = sum(1 for _ in df_all)
print(df_all_length)
>>out: 8189 # length is fine
train_size = int(0.6 * df_all_length)     # 4913
val_test_size = int(0.2 * df_all_length)  # 1637
# split whole dataset: first 60% -> train, next 20% -> test, remainder -> validation
df_train = df_all.take(train_size)
df_rest = df_all.skip(train_size)
df_test = df_rest.take(val_test_size)
df_valid = df_rest.skip(val_test_size)
# check sizes
df_train_length = sum(1 for _ in df_train)
df_valid_length = sum(1 for _ in df_valid)
df_test_length = sum(1 for _ in df_test)
print('Train:', df_train_length)
print('Validation:', df_valid_length)
print('Test:', df_test_length)
>>out: 4913 #(60% of 8189, rounded down)
>>out: 1639 #(the remainder, about 20% of 8189)
>>out: 1637 #(20% of 8189, rounded down)
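One caveat with the take/skip split: with split='train+test+validation' the examples presumably come out in that concatenation order, so take(train_size) would put all 1020 original training images plus the first 3893 test images into the training set. A minimal sketch of shuffling once before splitting, assuming TensorFlow 2.x; shuffled_split is a hypothetical helper, and tf.data.Dataset.range(8189) stands in for df_all so the example runs without downloading anything:

```python
import tensorflow as tf


def shuffled_split(ds, num_examples, train_frac=0.6, test_frac=0.2, seed=42):
    """Shuffle once in a fixed order, then carve train/test/validation with take/skip."""
    # reshuffle_each_iteration=False keeps the same order every time the dataset
    # is iterated, so take() and skip() see the same sequence and the three
    # parts stay disjoint.
    ds = ds.shuffle(num_examples, seed=seed, reshuffle_each_iteration=False)
    train_size = int(train_frac * num_examples)
    test_size = int(test_frac * num_examples)
    train = ds.take(train_size)
    rest = ds.skip(train_size)
    test = rest.take(test_size)
    valid = rest.skip(test_size)  # remainder goes to validation, as in the answer
    return train, test, valid


# Synthetic stand-in for the 8189 concatenated examples
all_ids = tf.data.Dataset.range(8189)
train_ds, test_ds, valid_ds = shuffled_split(all_ids, 8189)

train_ids = set(train_ds.as_numpy_iterator())
test_ids = set(test_ds.as_numpy_iterator())
valid_ids = set(valid_ds.as_numpy_iterator())
print(len(train_ids), len(test_ids), len(valid_ids))
```

With the real data you would pass df_all instead of all_ids. A buffer as large as the whole dataset gives a full shuffle, but holding 8189 decoded images may need a lot of memory; a smaller buffer_size only shuffles partially.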
Upvotes: 5