Reputation: 1175
After a lot of research, it seems like there is no good way to properly stop and resume training with a TensorFlow 2 / Keras model. This is true whether you are using model.fit() or a custom training loop.
There seem to be two supported ways to save a model while training:

1. Save just the weights of the model, using model.save_weights() or save_weights_only=True with tf.keras.callbacks.ModelCheckpoint. This seems to be preferred by most of the examples I've seen, but it has a number of major issues.
2. Save the entire model, optimizer, etc. using model.save() or save_weights_only=False. The optimizer state is saved (good), but other issues remain.

(Both options are sketched briefly below.)
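A minimal sketch of the two options with a toy model; the data, shapes, and paths are made up purely for illustration:

import numpy as np
import tensorflow as tf

# Toy model and data.
x = np.random.rand(100, 16).astype("float32")
y = np.random.rand(100, 10).astype("float32")
model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(16,))])
model.compile(optimizer="adam", loss="mse")

# Option 1: save only the weights at the end of each epoch via ModelCheckpoint.
ckpt_cb = tf.keras.callbacks.ModelCheckpoint("ckpt/weights.{epoch:02d}",
                                             save_weights_only=True)
model.fit(x, y, epochs=2, callbacks=[ckpt_cb])

# Option 2: save the entire model (architecture, weights, optimizer state).
model.save("saved_model")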
The best workaround I've found is to use a custom training loop, manually saving the step. This fixes the TensorBoard logging, and the learning rate schedule can be fixed by doing something like keras.backend.set_value(model.optimizer.iterations, step). However, since a full model save is off the table, the optimizer state is not preserved. I can see no way to save the state of the optimizer independently, at least without a lot of work. And manipulating the LR schedule this way feels hacky as well.
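As an illustration, a rough sketch of that resume step, assuming the weights were saved earlier with model.save_weights() and the step count was recorded separately (the model, checkpoint path, and step value here are hypothetical):

import tensorflow as tf

# Rebuild and compile the model exactly as before (toy model for illustration).
model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(16,))])
model.compile(optimizer="adam", loss="mse")

# Restore the weights, then realign the optimizer's step counter so the
# learning rate schedule and logging continue where they left off.
step = 12345                                   # step recorded before the interruption
model.load_weights("ckpt/weights")             # hypothetical checkpoint path
tf.keras.backend.set_value(model.optimizer.iterations, step)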
Am I missing something? How are people out there saving/resuming using this API?
Upvotes: 7
Views: 3298
Reputation: 61
Just use the BackupAndRestore callback:

import tensorflow as tf

# Backs up the training state to backup_dir and restores it if training is interrupted.
callback = tf.keras.callbacks.experimental.BackupAndRestore(
    backup_dir="backup_directory")
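A minimal sketch of using it with model.fit() (the toy model and data are made up for illustration); if training is interrupted, re-running the same script resumes from the last completed epoch:

import numpy as np
import tensorflow as tf

# Toy model and data.
x = np.random.rand(100, 16).astype("float32")
y = np.random.rand(100, 10).astype("float32")
model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(16,))])
model.compile(optimizer="adam", loss="mse")

callback = tf.keras.callbacks.experimental.BackupAndRestore(backup_dir="backup_directory")
model.fit(x, y, epochs=10, callbacks=[callback])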
Upvotes: 1
Reputation: 151
The tf.keras.callbacks.experimental.BackupAndRestore API for resuming training from interruptions was added in tensorflow>=2.3. It works great in my experience.
Reference: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/experimental/BackupAndRestore
Upvotes: 5
Reputation: 19776
You're right, there isn't built-in support for resumability, which is exactly what motivated me to create DeepTrain. It's like PyTorch Lightning (better and worse in different regards) for TensorFlow/Keras.
Why another library? Don't we have enough? You have nothing like this; if there were, I wouldn't have built it. DeepTrain is tailored to the "babysitting approach" to training: train fewer models, but train them thoroughly. Closely monitor each stage to diagnose what's wrong and how to fix it.
Inspiration came from my own use; I'd see "validation spikes" throughout a long epoch, and couldn't afford to pause, since doing so would restart the epoch or otherwise disrupt the training loop. And forget knowing which batch you were fitting, or how many remained.
How does it compare to PyTorch Lightning? Superior resumability and introspection, along with unique training debug utilities, though Lightning fares better in other regards. I have a comprehensive comparison list in the works and will post it within a week.
PyTorch support coming? Maybe. If I can convince the Lightning dev team to make up for its shortcomings relative to DeepTrain, then no; otherwise, probably. In the meantime, you can explore the gallery of Examples.
Minimal example:
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from deeptrain import TrainGenerator, DataGenerator

# Small Keras model, compiled as usual.
ipt = Input((16,))
out = Dense(10, 'softmax')(ipt)
model = Model(ipt, out)
model.compile('adam', 'categorical_crossentropy')

# DeepTrain generators for the train and validation data.
dg  = DataGenerator(data_path="data/train", labels_path="data/train/labels.npy")
vdg = DataGenerator(data_path="data/val",   labels_path="data/val/labels.npy")

# TrainGenerator drives training, logging to logs_dir.
tg = TrainGenerator(model, dg, vdg, epochs=3, logs_dir="logs/")
tg.train()
You can KeyboardInterrupt at any time, inspect the model, train state, and data generator, and resume.
Upvotes: 6