Reputation: 48
I am training a neural network on a large dataset, so I need multiple workers/multiprocessing to speed up training.
Previously I used the Keras generator as-is and called fit_generator with use_multiprocessing=False and workers=16, but recently I needed my own generator, so I wrapped flow_from_directory like this:
def balanced_flow_from_directory(flow_from_directory):
    for x, y in flow_from_directory:
        yield custom_balance(x, y)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(image_size, image_size),
    batch_size=training_batch_size,
    class_mode='categorical')  # set as training data

bal_gen = balanced_flow_from_directory(train_generator)
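For context, this is roughly how the generator is consumed (a sketch; everything other than bal_gen and the worker settings stands in for my actual values):

model.fit_generator(
    bal_gen,
    steps_per_epoch=len(train_generator),   # batches per epoch from the directory iterator
    epochs=num_epochs,                      # placeholder
    validation_data=validation_generator,   # placeholder for my validation generator
    workers=16,
    use_multiprocessing=False)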
However, when I kept workers > 1 with use_multiprocessing=False, fit_generator told me my generator isn't thread-safe and therefore can't be used with workers > 1 unless multiprocessing is enabled. When I kept workers > 1 and set use_multiprocessing=True, the code ran but gave warnings like:
WARNING:tensorflow:Using a generator with `use_multiprocessing=True` and multiple workers may duplicate your data. Please consider using the `tf.data.Dataset`.
Furthermore, validation produces very strange output, such as:
1661/1661 [==============================] - ETA: 0s - loss: 0.1420 - accuracy: 0.9662
WARNING:tensorflow:Using a generator with `use_multiprocessing=True` and multiple workers may duplicate your data. Please consider using the `tf.data.Dataset`.
1661/1661 [==============================] - 475s 286ms/step - loss: 0.1420 - accuracy: 0.9662 - val_loss: 6.2723 - val_accuracy: 0.0108
WARNING:tensorflow:multiprocessing can interact badly with TensorFlow, causing nondeterministic deadlocks. For high performance data pipelines tf.data is recommended.
The validation accuracy is always very low and the val_loss is always high. Is there something I could do to fix this?
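One option would be what the warning itself suggests: wrapping the generator in a tf.data.Dataset. A minimal sketch of what I have in mind, assuming float32 image batches with 3 channels and one-hot labels (num_classes is a placeholder, and output_signature requires TF 2.4+; older versions take output_types/output_shapes instead):

import tensorflow as tf

dataset = tf.data.Dataset.from_generator(
    lambda: balanced_flow_from_directory(train_generator),  # callable, so it can be re-created
    output_signature=(
        tf.TensorSpec(shape=(None, image_size, image_size, 3), dtype=tf.float32),  # batched images
        tf.TensorSpec(shape=(None, num_classes), dtype=tf.float32),                # one-hot labels
    )).prefetch(tf.data.AUTOTUNE)  # overlap batch preparation with training

Since the directory iterator loops forever, model.fit(dataset, steps_per_epoch=len(train_generator), ...) would still need steps_per_epoch, but tf.data then handles the parallelism instead of workers/use_multiprocessing.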
Update: I found code that makes a generator function thread-safe:
import threading

class threadsafe_iter:
    """
    Takes an iterator/generator and makes it thread-safe by
    serializing calls to the `next` method of the given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:  # only one thread may advance the underlying iterator at a time
            return next(self.it)

def threadsafe_generator(f):
    """Decorator that wraps a generator function's output in threadsafe_iter."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g

@threadsafe_generator
def balanced_flow_from_directory(flow_from_directory):
    for x, y in flow_from_directory:
        yield custom_balance(x, y)
Now I am able to use workers=16 with use_multiprocessing=False, as I did before I wrote the custom generator. However, an epoch now takes about 30 minutes, where it used to take about 7; presumably the lock serializes every next() call, so the 16 workers end up fetching batches one at a time anyway.
When I use workers=16 with use_multiprocessing=True, I get the same issues as above: the duplicate-data warnings and the broken validation accuracy.
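I'm wondering whether a keras.utils.Sequence would avoid both problems: it is index-based, so Keras can hand different batch indices to different workers without a lock and without duplicating data. A minimal sketch of what I mean (the class name is mine; it assumes the directory iterator supports indexing and on_epoch_end, which Keras' DirectoryIterator does):

from tensorflow.keras.utils import Sequence

class BalancedDirectorySequence(Sequence):
    """Index-based wrapper: each worker requests a batch by index, so no lock is needed."""
    def __init__(self, directory_iterator):
        self.it = directory_iterator

    def __len__(self):
        return len(self.it)  # number of batches per epoch

    def __getitem__(self, idx):
        x, y = self.it[idx]          # DirectoryIterator returns one batch per index
        return custom_balance(x, y)  # same balancing as the generator version

    def on_epoch_end(self):
        self.it.on_epoch_end()       # reshuffles the index order when shuffle=True

Then model.fit_generator(BalancedDirectorySequence(train_generator), workers=16, use_multiprocessing=False, ...) should parallelize without the thread-safety warning.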
Upvotes: 0
Views: 214
Reputation: 11
Perhaps you should apply the same data balancing function to your validation data generator?
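For example (a sketch; val_datagen, val_data_dir, and validation_batch_size stand in for your validation setup):

validation_generator = val_datagen.flow_from_directory(
    val_data_dir,
    target_size=(image_size, image_size),
    batch_size=validation_batch_size,
    class_mode='categorical')

val_gen = balanced_flow_from_directory(validation_generator)  # reuse the same wrapper

If custom_balance transforms x as well (rather than only rebalancing the classes), then validating on unbalanced batches means the model sees differently-distributed inputs at validation time, which could explain the near-zero val_accuracy.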
Upvotes: 1