Reputation: 1713
How can I make the Keras Model's fit method execute a generator in the main thread? From the docs, it looks like setting workers=0 should execute the generator in the main thread:
workers Integer. Used for generator or keras.utils.Sequence input only. Maximum number of processes to spin up when using process-based threading. If unspecified, workers will default to 1. If 0, will execute the generator on the main thread.
However when I do:
import tensorflow as tf
import threading

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(loss="mse", optimizer="adam")

def gen():
    for i in range(100):
        print(threading.current_thread())
        yield (tf.random.normal(shape=(100, 1)), tf.random.normal(shape=(100,)))

model.fit(gen(), epochs=1, workers=0, verbose=0, steps_per_epoch=3)
I get
<_MainThread(MainThread, started 140516450817920)>
<_DummyThread(Dummy-5, started daemon 140514709206784)>
<_DummyThread(Dummy-4, started daemon 140514717599488)>
<tensorflow.python.keras.callbacks.History at 0x7fcc1e8a8d68>
I interpret this as meaning that only the first step of the iterator was executed in the main thread. In my use case this is problematic because the code inside the generator must always execute in the main thread; otherwise the program crashes.
Upvotes: 3
Views: 495
Reputation: 19250
It seems like @Ena was on the right track. The following code runs each iteration of the generator on the main thread. workers must be set to 1; if it is set to 0, the iterations do not run on the main thread.
import tensorflow as tf
import threading

def gen():
    for i in range(10):
        current_thread = threading.current_thread()
        is_main = current_thread is threading.main_thread()
        print(f"is main: {is_main} | thread.name: {current_thread.name}")
        yield (tf.random.normal(shape=(100, 1)), tf.random.normal(shape=(100,)))

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(
    loss="mse",
    optimizer="adam",
)
model.fit(gen(), workers=1, use_multiprocessing=True, verbose=2)
This is the output:
is main: True | thread.name: MainThread
WARNING:tensorflow:multiprocessing can interact badly with TensorFlow, causing nondeterministic deadlocks. For high performance data pipelines tf.data is recommended.
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
is main: True | thread.name: Thread-31
10/10 - 1s - loss: 4.1511
<tensorflow.python.keras.callbacks.History at 0x7fcf780ef1d0>
It seems like any data passed to a model.fit call is converted to a tf.data.Dataset (see the relevant code in model.fit), and I have a hunch that tf.data.Dataset fetches data in separate threads. There is also a warning in the output about using tf.data instead of multiprocessing. If one uses tf.data.Dataset, however, the generator runs in a separate thread:
dset = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(100, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(100,), dtype=tf.float32),
    ),
)

for _ in dset:
    pass
The output is
is main: False | thread.name: Dummy-8
is main: False | thread.name: Dummy-4
is main: False | thread.name: Dummy-4
is main: False | thread.name: Dummy-4
is main: False | thread.name: Dummy-4
is main: False | thread.name: Dummy-8
is main: False | thread.name: Dummy-8
is main: False | thread.name: Dummy-8
is main: False | thread.name: Dummy-8
is main: False | thread.name: Dummy-4
Another option is to implement a custom training loop and bypass .fit() entirely. See https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough#train_the_model for an example.
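The core reason a custom loop sidesteps the threading problem is that you iterate the generator yourself, so each item is produced in whatever thread runs the loop. A minimal TF-free sketch of that property (gen and train_step are illustrative names of mine, not from the tutorial; a real step would do a forward pass and gradient update):

```python
import threading

def gen():
    # record whether each item is produced in the main thread
    for i in range(3):
        yield i, threading.current_thread() is threading.main_thread()

def train_step(batch):
    # stand-in for a real training step (forward pass, loss, optimizer update)
    return batch * 2

results = []
for batch, on_main in gen():
    # the generator body ran in this same thread, because we drive next() ourselves
    assert on_main
    results.append(train_step(batch))

print(results)  # [0, 2, 4]
```

Nothing in .fit() sits between you and the generator here, so there is no chance of the framework handing iteration off to a worker thread.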
Upvotes: 2
Reputation: 16737
I don't know how to resolve this issue inside Keras, so I implemented a workaround that can force any system (including Keras) to run the generator in the main thread. It is not short, but it works!
The idea is this: the main thread runs your generator and feeds its results into a synchronized queue, while another thread runs the model fitting and reads from that queue. The queue is filled lazily by the main thread, meaning that whenever the queue is full, the main thread pauses generation until a slot is freed.
The next example creates a queue with a maximum size of 10; this value appears in the code as queue.Queue(10). You can tweak this parameter to whatever you want. The larger it is, the more values the main thread produces in advance, which is better because values are pre-computed ahead of time. If you don't want any values produced ahead, set it to 1 (queue.Queue(1)), but this may slow down your training.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import threading, queue, time

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(loss="mse", optimizer="adam")

def gen():
    for i in range(100):
        print(threading.current_thread())
        yield (tf.random.normal(shape=(100, 1)), tf.random.normal(shape=(100,)))

finish = False
q = queue.Queue(10)

def main_gen(g):
    global finish
    for e in g:
        while True:
            try:
                q.put(e, timeout=0.01)
                break
            except queue.Full:
                if finish:
                    break
        if finish:
            break
    else:
        q.put('__FINISH__')

def thread_fit():
    global finish
    def thread_gen():
        while True:
            e = q.get()
            if e == '__FINISH__':
                break
            yield e
    model.fit(thread_gen(), epochs=1, workers=0, verbose=0, steps_per_epoch=3)
    finish = True

t0 = threading.Thread(target=thread_fit)
t0.start()
main_gen(gen())
t0.join()
Output:
<_MainThread(MainThread, started 13540)>
<_MainThread(MainThread, started 13540)>
<_MainThread(MainThread, started 13540)>
<_MainThread(MainThread, started 13540)>
<_MainThread(MainThread, started 13540)>
<_MainThread(MainThread, started 13540)>
<_MainThread(MainThread, started 13540)>
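Stripped of the Keras parts, the handoff in the code above is a standard bounded producer/consumer queue: the main thread blocks on put() while the queue is full, and a worker thread (playing the role of thread_fit above) drains it. A minimal stdlib-only sketch (the names producer, consumer, and SENTINEL are mine, not from the code above):

```python
import queue
import threading

q = queue.Queue(maxsize=2)  # bounded: put() blocks while the queue is full
SENTINEL = object()         # end-of-stream marker, like '__FINISH__' above

def producer(items):
    # runs in the main thread: generation proceeds lazily,
    # only as fast as the consumer drains the queue
    for item in items:
        q.put(item)
    q.put(SENTINEL)

def consumer(out):
    # runs in a worker thread, consuming until the sentinel arrives
    while True:
        item = q.get()
        if item is SENTINEL:
            break
        out.append(item)

out = []
t = threading.Thread(target=consumer, args=(out,))
t.start()
producer(range(5))  # the main thread produces every value
t.join()
print(out)  # [0, 1, 2, 3, 4]
```

Using a unique sentinel object instead of the string '__FINISH__' avoids colliding with legitimate data, though for Keras batches the string works fine in practice.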
Upvotes: 2
Reputation: 126
Unfortunately, I still can't comment, so I will write it down here. If this answer is not helpful, please comment so I can delete it.
Have you tried using use_multiprocessing=True in model.fit()? Doc: https://keras.io/api/models/model_training_apis/
Upvotes: 0