tammy

Splitting the main data directory into train/validation/test sets

I am working on X-ray image classification. My data is stored in one directory and I need to divide it into training, validation and test sets. I managed to separate the training and validation sets using ImageDataGenerator, but I am having trouble separating out the test set. Here's my code:

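import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
import splitfolders  # the split-folders package (pip install split-folders) is imported as splitfolders

# Path
Images = 'data_processed_cropped_32'
data_set = os.path.join(r'C:\Users\320067835\Desktop\Thesis\Data\png', Images)

#splitfolders.ratio(data_set, output="output", seed=1337, ratio=(0.8, 0.1, 0.1))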

# Image size
img_width = 32
img_height = 32

# Data augmentation
data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale = 1/255, horizontal_flip = True,
                                                            rotation_range = 0,validation_split=0.2)

train_set = data_gen.flow_from_directory(data_set, target_size = (img_width, img_height), color_mode = 'grayscale',
                                        class_mode = 'categorical', batch_size = 32, interpolation = 'nearest',
                                        subset ='training')

validation_set = data_gen.flow_from_directory(data_set, target_size= (img_width,img_height), color_mode='grayscale',
                                              batch_size=32, class_mode='categorical', interpolation= 'nearest',
                                              subset='validation')
# Build a model
cnn = Sequential()

cnn.add(keras.Input(shape = (img_height, img_width, 1)))  # grayscale input; the Input layer makes input_shape on Conv2D redundant
cnn.add(Conv2D(16,(3,3), padding = 'same', activation = 'relu'))
cnn.add(MaxPooling2D(2,2))
cnn.add(Conv2D(32,(3,3), padding = 'same', activation = 'relu'))
cnn.add(MaxPooling2D(2,2))

cnn.add(Flatten())

cnn.add(Dense(units = 100, activation = 'relu'))
cnn.add(Dense(units = 50, activation = 'relu'))
cnn.add(Dense(units=23, activation = 'softmax'))
cnn.summary()
cnn.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['accuracy'])

cnn.fit(train_set,validation_data = validation_set,epochs = 20)

I tried using split-folders, but it didn't work. I think I'm mostly not using it correctly, partly because I don't know how I would then access the three folders after splitting the data. Or is there another method I can use to split off my test set?


Answers (1)

Gerry P

I have needed to do this often, so I developed a thorough function to do the splitting. It is rather lengthy because it performs a lot of checks. The code is posted below.

import os
import shutil
from tqdm import tqdm
from sklearn.model_selection import train_test_split

def tr_te_val_split(s_dir, dest_dir, train_size, test_size): 
    # strict bounds: both train_test_split calls below need a non-empty share on each side
    if train_size <= 0 or train_size >= 1:
        print('*** Train size must be a float strictly between 0.0 and 1.0, process terminated ***')
        return
    if test_size <= 0 or test_size >= 1:
        print('*** Test size must be a float strictly between 0.0 and 1.0, process terminated ***')
        return
    if test_size + train_size >= 1:
        print('*** The sum of the train size and the test size must be less than 1, process terminated ***')
        return
    
    remainder = 1 - train_size  # fraction left over for test and validation
    test_size = test_size / remainder  # test fraction re-expressed relative to that remainder
    if not os.path.isdir(dest_dir):
        os.makedirs(dest_dir)
        print('The dest_dir you specified,', dest_dir, ', did not exist, so it was created for you')
    dest_list=os.listdir(dest_dir) # list content of destination directory
    for d in ['train', 'test', 'valid']:
        d_path=os.path.join(dest_dir,d)
        if d not in dest_list:
            os.mkdir(d_path)  # create train, test and valid directories in the destination directory
        else: # check to see if there are any files in these directories
            d_list=os.listdir(d_path)
            if len(d_list) > 0:  # there are files or directories in d
                cycle=True
                print('*** WARNING***  there is content in ', d_path)
                while cycle:
                    ans=input(' enter D to delete content, C to continue and keep content or Q to Quit ')
                    if ans not in ['D', 'd', 'C', 'c', 'Q', 'q']:
                        print('your response ', ans, ' was not a  D, C or Q, try again')
                    else:
                        cycle=False
                        if ans in ['Q', 'q']:
                            print ('**** PROCESS TERMINATED BY USER ****')
                            return
                        else:
                            if ans in ['D', 'd']:
                                print(' Removing all files and sub directories in ', d_path)
                                for f in d_list:
                                    f_path=os.path.join (d_path,f)
                                    if os.path.isdir(f_path):                                        
                                        shutil.rmtree(f_path)                                        
                                    else:
                                        os.remove(f_path)
            
    class_list=[c for c in os.listdir(s_dir) if os.path.isdir(os.path.join(s_dir, c))]  # list of class sub-directories (stray files are skipped)
    for klass in tqdm(class_list): # iterate through the classes
        klass_path=os.path.join(s_dir, klass) # path to class directory
        f_list=os.listdir(klass_path) # get the list of file names
        ftrain, ftv= train_test_split(f_list, train_size=train_size, random_state=123 )
        ftest, fvalid= train_test_split(ftv, train_size= test_size, random_state=123 )        
        for d in ['train', 'test', 'valid']:
            d_path=os.path.join(dest_dir,d)
            d_class_path=os.path.join(d_path,klass)
            if os.path.isdir(d_class_path)==False:
                os.mkdir(d_class_path)
            if d=='train':
                fx=ftrain
            elif d=='test':
                fx=ftest
            else:
                fx=fvalid
            for f in fx:
                f_path=os.path.join(klass_path, f)
                d_f_path=os.path.join(d_class_path,f)
                shutil.copy(f_path, d_f_path)
    for d in ['train', 'test', 'valid']:
        file_count=0
        d_path=os.path.join(dest_dir, d)
        for klass in os.listdir(d_path):
            klass_path=os.path.join(d_path, klass)
            file_count=file_count + len(os.listdir(klass_path))
        # record the total once per subset, so the counts are defined even if a subset is empty
        if d == 'train':
            tr_count=file_count
        elif d =='test':
            te_count=file_count
        else:
            tv_count=file_count
    print ('Process completed:', tr_count, 'training files,', te_count, 'test files and', tv_count, 'validation files were partitioned')
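The two lines that rescale test_size (remainder = 1 - train_size, then test_size = test_size / remainder) are the least obvious part of the function. Because the second train_test_split call only receives the files left over after the training split, the overall test fraction has to be re-expressed relative to that remainder. The helper name rescaled_test_fraction below is hypothetical; it just reproduces that arithmetic for an 80/10/10 split. The exact per-class counts train_test_split produces may differ by a file for small classes, depending on how it rounds.

```python
def rescaled_test_fraction(train_size, test_size):
    """Hypothetical helper reproducing the rescaling step of tr_te_val_split.

    The second split only sees the files left after the training split,
    so the overall test fraction is divided by that leftover fraction.
    """
    remainder = 1 - train_size  # fraction left for test + validation
    return test_size / remainder


# Overall 80/10/10 split: the second call must send half of the leftover 20% to test.
frac = rescaled_test_fraction(0.8, 0.1)
print(round(frac, 6))  # 0.5

# Rough counts for a class of 1000 images (truncating, as an approximation).
n = 1000
n_train = int(n * 0.8)
n_test = int((n - n_train) * frac)
print(n_train, n_test, n - n_train - n_test)  # 800 100 100
```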

This function splits the files in s_dir into train, test and validation files stored in dest_dir. s_dir is the full path to the directory containing the files to be split, with one sub-directory per class. dest_dir is the full path to the destination directory; if it does not exist, it is created. train_size is a float between 0.0 and 1.0 giving the fraction of files allocated to training, and test_size is a float between 0.0 and 1.0 giving the fraction allocated to testing; the remaining files go to validation.

Three sub-directories, 'train', 'test' and 'valid', are created in dest_dir to hold the training, test and validation files. If these sub-directories already exist, they are checked for existing content. If content is found, a notice is printed and the user is prompted to enter 'D' to delete the content, 'Q' to terminate the program or 'C' to continue. If 'C' is selected, the existing content is kept and the sub-directories and files from s_dir are added alongside it; any existing file with the same name as a new file is overwritten.

This function uses tqdm and sklearn, which must be installed in your working environment.
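If you would rather avoid the tqdm and sklearn dependencies and the interactive prompt, the same per-class copy can be sketched with the standard library alone. This is a simplified variant under stated assumptions, not the function above: split_class_dirs is a hypothetical name, it swaps sklearn's train_test_split for a seeded random.Random shuffle, it raises ValueError instead of printing, and it never deletes existing content.

```python
import os
import random
import shutil
import tempfile


def split_class_dirs(s_dir, dest_dir, train_size, test_size, seed=123):
    """Hypothetical stdlib-only variant of the per-class split.

    Copies the files of every class sub-directory in s_dir into
    dest_dir/train/<class>, dest_dir/test/<class> and dest_dir/valid/<class>
    and returns the number of files copied into each subset.
    """
    if not (0 < train_size < 1 and 0 < test_size < 1 and train_size + test_size < 1):
        raise ValueError("need 0 < train_size, 0 < test_size and train_size + test_size < 1")
    rng = random.Random(seed)  # seeded so the split is reproducible
    counts = {"train": 0, "test": 0, "valid": 0}
    for klass in sorted(os.listdir(s_dir)):
        klass_path = os.path.join(s_dir, klass)
        if not os.path.isdir(klass_path):
            continue  # skip stray files at the top level of s_dir
        files = sorted(os.listdir(klass_path))  # sort first so the shuffle is deterministic
        rng.shuffle(files)
        n_train = int(len(files) * train_size)  # truncating; small classes round down
        n_test = int(len(files) * test_size)
        subsets = {
            "train": files[:n_train],
            "test": files[n_train:n_train + n_test],
            "valid": files[n_train + n_test:],  # whatever is left goes to validation
        }
        for subset, names in subsets.items():
            out_dir = os.path.join(dest_dir, subset, klass)
            os.makedirs(out_dir, exist_ok=True)
            for name in names:
                shutil.copy(os.path.join(klass_path, name), os.path.join(out_dir, name))
            counts[subset] += len(names)
    return counts


# Demo on a throwaway tree: two classes with 10 empty "images" each.
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "src")
    for klass in ("cat", "dog"):
        os.makedirs(os.path.join(src, klass))
        for i in range(10):
            open(os.path.join(src, klass, f"{i}.png"), "w").close()
    print(split_class_dirs(src, os.path.join(tmp, "out"), 0.8, 0.1))
    # {'train': 16, 'test': 2, 'valid': 2}
```

Because the output keeps one sub-directory per class inside each subset, each of dest_dir/train, dest_dir/valid and dest_dir/test can be read with its own flow_from_directory call, so validation_split is no longer needed. Passing shuffle=False for the test generator keeps its batch order aligned with its .classes labels.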
