Daniel Falbel

Reputation: 1713

Keras fit with a generator function: always execute the generator in the main thread

How can I make the Keras Model's fit method execute a generator in the main thread? From the docs, it looks like setting workers=0 would execute the code in the main thread.

workers Integer. Used for generator or keras.utils.Sequence input only. Maximum number of processes to spin up when using process-based threading. If unspecified, workers will default to 1. If 0, will execute the generator on the main thread.

However when I do:

import tensorflow as tf
import threading
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(loss = "mse", optimizer = "adam")

def gen():
  for i in range(100):
    print(threading.current_thread())
    yield (tf.random.normal(shape=(100,1)), tf.random.normal(shape = (100,)))

model.fit(gen(), epochs = 1, workers = 0, verbose = 0, steps_per_epoch = 3)

I get

<_MainThread(MainThread, started 140516450817920)>
<_DummyThread(Dummy-5, started daemon 140514709206784)>
<_DummyThread(Dummy-4, started daemon 140514717599488)>
<tensorflow.python.keras.callbacks.History at 0x7fcc1e8a8d68>

which I interpret as meaning that only the first step of the iterator was executed in the main thread.

In my use case this is problematic because the code inside the generator must always be executed in the main thread; otherwise the program crashes.

Upvotes: 3

Views: 495

Answers (3)

jkr

Reputation: 19250

It seems @Ena was on the right track. The following code runs each iteration of the generator on the main thread. workers must be set to 1; if it is set to 0, the iterations do not run on the main thread.

import tensorflow as tf
import threading


def gen():
    for i in range(10):
        current_thread = threading.current_thread()
        is_main = current_thread is threading.main_thread()
        print(f"is main: {is_main} | thread.name: {current_thread.name}")
        yield (tf.random.normal(shape=(100, 1)), tf.random.normal(shape=(100,)))


model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(
    loss="mse",
    optimizer="adam",
)
model.fit(gen(), workers=1, use_multiprocessing=True, verbose=2)

This is the output:

is main: True | thread.name: MainThread
WARNING:tensorflow:multiprocessing can interact badly with TensorFlow, causing nondeterministic deadlocks. For high performance data pipelines tf.data is recommended.
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
10/10 - 1s - loss: 4.1511

<tensorflow.python.keras.callbacks.History at 0x7fcf780ef1d0>

It seems that any data passed to a model.fit call is converted to a tf.data.Dataset. See the relevant code in model.fit. And I have a hunch that tf.data.Dataset fetches data in separate threads.

There is also a warning in the output about using tf.data instead of multiprocessing. If one uses tf.data.Dataset, however, the generator runs in a separate thread:

dset = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(100, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(100,), dtype=tf.float32),
    ),
)

for _ in dset:
    pass

The output is:

is main: False | thread.name: Dummy-8
is main: False | thread.name: Dummy-4
is main: False | thread.name: Dummy-4
is main: False | thread.name: Dummy-4
is main: False | thread.name: Dummy-4
is main: False | thread.name: Dummy-8
is main: False | thread.name: Dummy-8
is main: False | thread.name: Dummy-8
is main: False | thread.name: Dummy-8
is main: False | thread.name: Dummy-4

Another option is to implement a custom training loop and bypass .fit() entirely. See https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough#train_the_model for an example.
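For completeness, here is a minimal sketch of such a loop, reusing the model and gen() defined above (the loss/optimizer objects and the 3-step cutoff are only illustrative and are not taken from the linked tutorial). Because the plain Python for loop drives the generator directly, every batch is produced on whatever thread runs the loop, i.e. the main thread in an ordinary script.

# Sketch of a custom training loop: iterating gen() ourselves keeps
# batch creation on the calling (main) thread.
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()

for step, (x, y) in enumerate(gen()):
    y = tf.reshape(y, (-1, 1))  # match the (batch, 1) model output
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = loss_fn(y, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    if step >= 2:  # mimic steps_per_epoch=3 from the question
        break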

Upvotes: 2

Arty

Reputation: 16737

I don't know how to resolve this issue inside Keras, so I decided to implement a workaround for you that can force any system (including Keras) to run the generator on the main thread. It is not very short, but it works!

The idea is as follows: the main thread runs your generator and feeds its results into a synchronized queue, while a second thread runs the model fitting and reads from that queue. The queue is filled by the main thread lazily, meaning that whenever the queue is full the main thread pauses generation and waits for a slot to be freed.

The example below creates a queue with a maximum size of 10; you can see this value in the code as queue.Queue(10) and tweak it as you like. The larger it is, the more values the main thread produces in advance, which is usually better because batches are pre-computed ahead of time. If you don't want any look-ahead, set it to 1 (queue.Queue(1)), but this may slow down your training.

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import threading, queue

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(loss = "mse", optimizer = "adam")

def gen():
    for i in range(100):
        print(threading.current_thread())
        yield (tf.random.normal(shape=(100,1)), tf.random.normal(shape = (100,)))

finish = False
q = queue.Queue(10)  # at most 10 pre-computed batches waiting in the queue

def main_gen(g):
    # Runs in the main thread: drains the real generator into the queue.
    global finish
    for e in g:
        while True:
            try:
                q.put(e, timeout = 0.01)
                break
            except queue.Full:
                if finish:  # fitting thread is done, stop producing
                    break
        if finish:
            break
    else:
        # Generator exhausted normally: tell the consumer to stop.
        q.put('__FINISH__')

def thread_fit():
    # Runs in a worker thread: feeds queued batches to model.fit.
    global finish
    def thread_gen():
        while True:
            e = q.get()
            if e == '__FINISH__':
                break
            yield e
    model.fit(thread_gen(), epochs = 1, workers = 0, verbose = 0, steps_per_epoch = 3)
    finish = True  # unblocks main_gen once fitting is complete

t0 = threading.Thread(target = thread_fit)
t0.start()

main_gen(gen())  # the generator itself runs here, in the main thread

t0.join()

Output:

<_MainThread(MainThread, started 13540)>
<_MainThread(MainThread, started 13540)>
<_MainThread(MainThread, started 13540)>
<_MainThread(MainThread, started 13540)>
<_MainThread(MainThread, started 13540)>
<_MainThread(MainThread, started 13540)>
<_MainThread(MainThread, started 13540)>

Upvotes: 2

Ena

Reputation: 126

Unfortunately, I still can't comment, so I will write this here. If this answer is not helpful, please comment so I can delete it.

Have you tried using use_multiprocessing=True in model.fit()? Docs: https://keras.io/api/models/model_training_apis/
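For reference, a minimal sketch of what that call could look like with the model and gen() from the question; per jkr's answer above, workers also needs to be 1 rather than 0 for this to take effect (the argument values here are only illustrative):

model.fit(
    gen(),
    epochs = 1,
    steps_per_epoch = 3,
    workers = 1,                  # at least one worker, per jkr's answer
    use_multiprocessing = True,   # the suggestion above
    verbose = 0,
)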

Upvotes: 0
