Ailsor

Reputation: 48

Keras Multiprocessing breaks validation accuracy

I am training a neural network with a large dataset and therefore I need to use multiple workers/multiprocessing to speed up the training.

Previously I was using the Keras generator as-is, calling fit_generator with use_multiprocessing=False and workers=16. Recently, however, I needed my own generator, so I wrapped the flow_from_directory generator as below:

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(image_size, image_size),
    batch_size=training_batch_size,
    class_mode='categorical')  # set as training data

def balanced_flow_from_directory(flow_from_directory):
    for x, y in flow_from_directory:
        yield custom_balance(x, y)

bal_gen = balanced_flow_from_directory(train_generator)

However, in fit_generator with workers > 1 and use_multiprocessing=False, it would tell me my generator isn't thread-safe and so can't be used with workers > 1. With workers > 1 and use_multiprocessing=True, the code runs, but it gives me warnings like:

WARNING:tensorflow:Using a generator with use_multiprocessing=True and multiple workers may duplicate your data. Please consider using the tf.data.Dataset

Furthermore, the validation gives very weird outputs such as:

1661/1661 [==============================] - ETA: 0s - loss: 0.1420 - accuracy: 0.9662
WARNING:tensorflow:Using a generator with use_multiprocessing=True and multiple workers may duplicate your data. Please consider using the tf.data.Dataset.
1661/1661 [==============================] - 475s 286ms/step - loss: 0.1420 - accuracy: 0.9662 - val_loss: 6.2723 - val_accuracy: 0.0108

The validation accuracy is always very low and the val_loss is always high. Is there something I could do to fix this?
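For context, the duplication warning exists because use_multiprocessing=True gives each worker process its own copy of a plain Python generator, so every worker replays the same batches. A minimal stdlib sketch of the effect (counting_gen is a made-up stand-in for the real batch generator):

```python
import itertools

def counting_gen():
    # Stand-in for balanced_flow_from_directory: yields batch indices.
    for i in itertools.count():
        yield i

# Threads (workers > 1, use_multiprocessing=False) share one iterator,
# so its state advances across workers and no batch is repeated:
shared = counting_gen()
worker_a = [next(shared) for _ in range(3)]  # [0, 1, 2]
worker_b = [next(shared) for _ in range(3)]  # [3, 4, 5]

# Processes (use_multiprocessing=True) each iterate their own copy,
# so every worker sees the same batches again:
copy_a = counting_gen()
copy_b = counting_gen()
dup_a = [next(copy_a) for _ in range(3)]  # [0, 1, 2]
dup_b = [next(copy_b) for _ in range(3)]  # [0, 1, 2] -- duplicated data
```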


Update: I found code that makes the generator function thread-safe, as follows:

import threading

class threadsafe_iter:
    """
    Takes an iterator/generator and makes it thread-safe by
    serializing calls to the `next` method of the given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return self.it.__next__()

def threadsafe_generator(f):
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))

    return g

@threadsafe_generator
def balanced_flow_from_directory(flow_from_directory):
    for x, y in flow_from_directory:
        yield custom_balance(x, y)
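As a sanity check on that wrapper, here is a small self-contained demo (everything except threadsafe_iter is made up for the test) showing that several threads can drain one locked generator without losing or duplicating items:

```python
import threading

class threadsafe_iter:
    """Serializes calls to next() on the wrapped iterator with a lock."""
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

def numbers():
    # Stand-in for a batch generator with a known, finite output.
    yield from range(1000)

safe = threadsafe_iter(numbers())
results = []

def worker():
    while True:
        try:
            item = next(safe)
        except StopIteration:
            return
        results.append(item)  # list.append is atomic in CPython

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every item appears exactly once, regardless of thread interleaving:
assert sorted(results) == list(range(1000))
```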

Now I am able to use workers=16 with use_multiprocessing=False, as I did before I made my custom generator. However, each epoch now takes 30 minutes, whereas it used to take about 7.

When I use workers=16 with use_multiprocessing=True, I get the same issues as above - namely, the validation accuracy breaking.

Upvotes: 0

Views: 214

Answers (1)

Brylle Gomez

Reputation: 11

Perhaps you should apply the same data balancing function to your validation data generator?
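The intuition: if custom_balance changes the class distribution of the training batches while the validation generator is left raw, the model is evaluated on a differently distributed stream. The question doesn't show custom_balance, so here is a hypothetical oversampling version, purely to illustrate the mismatch:

```python
from collections import Counter

def custom_balance(xs, ys):
    # Hypothetical balancer: oversample each class up to the size of the
    # largest class in the batch (the real function isn't shown).
    counts = Counter(ys)
    target = max(counts.values())
    bx, by = [], []
    for x, y in zip(xs, ys):
        reps = target // counts[y]
        bx.extend([x] * reps)
        by.extend([y] * reps)
    return bx, by

raw_batch_x = ['img0', 'img1', 'img2', 'img3']
raw_batch_y = [0, 0, 0, 1]  # imbalanced: 3 of class 0, 1 of class 1

_, train_y = custom_balance(raw_batch_x, raw_batch_y)

print(Counter(train_y))      # Counter({0: 3, 1: 3}) -- balanced training stream
print(Counter(raw_batch_y))  # Counter({0: 3, 1: 1}) -- raw validation stream
# Wrapping the validation generator with balanced_flow_from_directory
# would remove this train/validation distribution mismatch.
```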

Upvotes: 1
