Reputation: 303
I am training a GPT-2 text-generation model in TensorFlow and am running a single epoch over my text corpus. How can I save my model every 10 steps or so? Training stopped abruptly at step 100 with only 20 steps to go... oooof.
I'm aware of the `ModelCheckpoint()` callback, but it doesn't appear that I can replace `'epoch'` with a number of steps in the `save_freq` parameter:
```python
tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch', **kwargs)
```
https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint
Upvotes: 0
Views: 719
Reputation: 1056
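Set `save_freq=1`. When `save_freq` is an integer instead of `'epoch'`, the callback saves after that many batches, so this saves at every step. I would not recommend it, though: each save spends time on disk I/O, which will slow your training down. For the 10-step interval you describe, `save_freq=10` is the better trade-off.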
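A minimal sketch of saving every 10 steps, assuming TensorFlow 2.x with a Keras version whose `ModelCheckpoint` fills a `{batch}` placeholder in `filepath` (older releases may only support `{epoch}`). The tiny `Dense` model and random arrays are hypothetical stand-ins for the GPT-2 model and corpus; only the callback arguments matter here.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for the real GPT-2 model and text corpus.
x = np.random.rand(400, 4).astype("float32")
y = np.random.rand(400, 1).astype("float32")
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

ckpt_dir = tempfile.mkdtemp()
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    # {epoch} and {batch} are filled with 1-based numbers, so each save
    # gets its own file instead of overwriting the previous one.
    filepath=os.path.join(ckpt_dir, "epoch{epoch:02d}-batch{batch:04d}.weights.h5"),
    save_weights_only=True,  # weights only: smaller files, less I/O per save
    save_freq=10,            # an int means "every 10 batches", not epochs
)

# 400 samples / batch_size 10 = 40 steps in the single epoch,
# so the callback should write 4 checkpoints (after steps 10, 20, 30, 40).
model.fit(x, y, batch_size=10, epochs=1, callbacks=[checkpoint], verbose=0)

saved = sorted(os.listdir(ckpt_dir))
print(saved)
```

If a run dies partway through, the newest file shows how far training got, and `model.load_weights(path)` on that file restores the weights before you continue.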
Upvotes: 1