I am training with Keras model.fit, and the data comes from TFRecords loaded into a tf.data object, which uses .shuffle to shuffle the data. I am also using callbacks.ModelCheckpoint to save the model every x steps/batches.
Sometimes my cloud instance disconnects or crashes before an epoch is finished, but the model at step y has been saved to my drive.
I would like to finish training over the data in that epoch (I have very long epochs) before training another epoch, so that each data example is trained on once per epoch.
Is there a way to get the original order of the data, and the position within the data where the model was last saved?
It looks like you can fix the order produced by .shuffle by setting the seed. However, shuffling only happens within the buffer, so I am not 100% sure that setting the seed will reproduce the order exactly. I am also not sure how this interacts with reshuffle_each_iteration. Is a different seed used for each epoch? If so, I guess a workaround is to train only one epoch at a time, with a specified seed for each epoch.
Even if I do get a replica of the training order, I'm not sure how to find where in that order the model was last saved, and then start training from that point. One idea for getting there is to iterate through the dataset manually until I reach it, although I'm not sure whether model.fit() would continue from that point or start all over.
For getting the step/batch number at which the model was last saved, I could probably log it somewhere.
These solutions seem like rough workarounds, and I am wondering whether there is some feature in Keras that I may be overlooking to help with this.
There seems to be no built-in Keras way to do this, but please correct me if I am wrong.
Dataset.shuffle internally uses the initial seed value to generate the seeds used for reshuffling on each iteration when reshuffle_each_iteration=True. So to re-create the order of a particular epoch and continue training that epoch from a particular batch, we have to re-create the Dataset with the same seed and advance the dataset iterator to the same epoch and batch.
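To make that mechanism concrete, here is a pure-Python model of a seeded buffered shuffle. It follows the documented behaviour of Dataset.shuffle (fill a buffer with buffer_size elements, emit a random one, refill its slot from the input) and derives one seed per epoch from a base seed. It is only an illustration: buffered_shuffle, epoch_seeds and epoch_order are hypothetical names, and TensorFlow's real random number generator and seed derivation differ, so the orders it prints will not match TF's.

```python
import random
from typing import List, Sequence


def buffered_shuffle(items: Sequence[int], buffer_size: int, seed: int) -> List[int]:
    """Illustrative model of a buffered shuffle (not TensorFlow's real RNG)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)
    source = iter(items)
    buffer: List[int] = []
    # Fill the buffer with up to `buffer_size` elements from the input.
    for item in source:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    out: List[int] = []
    while buffer:
        # Emit a random buffered element and refill its slot from the input.
        idx = rng.randrange(len(buffer))
        out.append(buffer[idx])
        nxt = next(source, None)
        if nxt is None:
            buffer.pop(idx)  # input exhausted: the buffer drains
        else:
            buffer[idx] = nxt
    return out


def epoch_seeds(base_seed: int, num_epochs: int) -> List[int]:
    """Derive one seed per epoch from a single base seed (hypothetical scheme)."""
    gen = random.Random(base_seed)
    return [gen.getrandbits(32) for _ in range(num_epochs)]


def epoch_order(items: Sequence[int], buffer_size: int, base_seed: int, epoch: int) -> List[int]:
    """Order of `items` in a given epoch, rebuilt from the base seed and epoch index only."""
    return buffered_shuffle(items, buffer_size, epoch_seeds(base_seed, epoch + 1)[epoch])


if __name__ == "__main__":
    data = list(range(15))
    for e in range(3):
        print(e, epoch_order(data, buffer_size=8, base_seed=0, epoch=e))
```

In this model, restricting the shuffle to a buffer does not hurt reproducibility: the only randomness is the seeded generator, so the same seed replays the same picks. Because each epoch's order depends only on the base seed and the epoch index, rebuilding epoch e needs nothing but a count of the epochs already completed.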
For debugging, and to make sure the epochs and batches are generated in the same order, we need a way to print which data points are picked up in each epoch and batch. This is tricky in Keras, so for debugging purposes I will use a regression problem with sequential numbers as the ground truth. A custom loss can then print the ground truth, which lets us check that the order is correct.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
import numpy as np
# Data: the targets are 0..14, so printing y_true reveals the sample order
x_train = np.random.randn(15, 10).astype("float32")
y_train = np.arange(15).astype("float32")
# Custom MSE loss, used only to track the order in which data is picked up
def my_mse(y_true, y_pred):
    tf.print(tf.keras.backend.flatten(y_true))
    loss = K.square(y_pred - y_true)
    loss = K.sum(loss, axis=1)
    return loss
# Model
def get_model():
    inputs = keras.Input(shape=(10,))
    outputs = layers.Dense(1, activation="linear")(inputs)
    model = keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer="rmsprop",
        loss=my_mse,
    )
    return model
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(8)
epochs = 2
print("Runs 1")
for e in range(epochs):
    for i, (x, y) in enumerate(train_dataset):
        print(e, i, y)
# Re-create the dataset from scratch with the same seed
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(8)
print("Runs 2")
for e in range(epochs):
    for i, (x, y) in enumerate(train_dataset):
        print(e, i, y)
Output:
Runs 1
0 tf.Tensor([1. 3. 5. 7. 4. 0. 8. 2.], shape=(8,), dtype=float32)
1 tf.Tensor([ 6. 11. 10. 14. 9. 12. 13.], shape=(7,), dtype=float32)
2 tf.Tensor([4. 2. 5. 8. 1. 9. 7. 3.], shape=(8,), dtype=float32)
3 tf.Tensor([13. 10. 0. 14. 6. 11. 12.], shape=(7,), dtype=float32)
4 tf.Tensor([ 0. 1. 5. 6. 9. 3. 7. 14.], shape=(8,), dtype=float32)
5 tf.Tensor([13. 8. 4. 10. 2. 12. 11.], shape=(7,), dtype=float32)
Runs 2
0 tf.Tensor([1. 3. 5. 7. 4. 0. 8. 2.], shape=(8,), dtype=float32)
1 tf.Tensor([ 6. 11. 10. 14. 9. 12. 13.], shape=(7,), dtype=float32)
2 tf.Tensor([4. 2. 5. 8. 1. 9. 7. 3.], shape=(8,), dtype=float32)
3 tf.Tensor([13. 10. 0. 14. 6. 11. 12.], shape=(7,), dtype=float32)
4 tf.Tensor([ 0. 1. 5. 6. 9. 3. 7. 14.], shape=(8,), dtype=float32)
5 tf.Tensor([13. 8. 4. 10. 2. 12. 11.], shape=(7,), dtype=float32)
Yes, with the seed the order is reproduced.
Now let's write a method to forward the dataset to a certain epoch and batch combination:
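def forward(dataset, n=None):
    # Iterate over `dataset` until n batches have been consumed (counting across
    # epochs), advancing its reshuffle state as the interrupted run did
    if not n:
        return dataset
    i = 0
    while True:
        for _ in dataset:
            i += 1
            if i == n:
                return dataset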
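In the runs here the dataset is forwarded by 4 batches, which with 15 samples and a batch size of 4 is exactly one full epoch. If a checkpoint lands mid-epoch, it helps to split the global batch counter into completed epochs and batches already consumed in the interrupted epoch. The helper names below are hypothetical and assume the last partial batch is kept (the default drop_remainder=False). As an assumption worth verifying on your TF version: model.fit appears to start a new pass over the dataset rather than resume the iterator abandoned inside forward, so the leftover batches of an interrupted epoch may need separate handling.

```python
import math
from typing import Tuple


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Batches per epoch when the last partial batch is kept (drop_remainder=False)."""
    return math.ceil(num_examples / batch_size)


def split_global_step(global_step: int, num_examples: int, batch_size: int) -> Tuple[int, int]:
    """Turn a global batch counter into (completed_epochs, batches_done_in_current_epoch)."""
    return divmod(global_step, steps_per_epoch(num_examples, batch_size))


if __name__ == "__main__":
    # 15 samples in batches of 4 -> 4 batches per epoch
    print(split_global_step(4, num_examples=15, batch_size=4))
    print(split_global_step(6, num_examples=15, batch_size=4))
```

A counter of 4 maps to one completed epoch and no leftover batches, which is why forwarding by 4 lines up cleanly with the second epoch's order; a counter of 6 would leave 2 batches of the second epoch already consumed.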
Let's run it normally and observe the order:
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), None)
model = get_model()
model.fit(train_dataset, epochs=3, verbose=0, workers=4, shuffle=False)
Output:
[7 3 6 10]
[11 0 1 2]
[8 14 9 13]
[12 5 4]
[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]
Let's forward our dataset to the 4th iteration and run the training:
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), 4)
model = get_model()
model.fit(train_dataset, epochs=3, verbose=0, workers=4, shuffle=False)
Output:
[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]
Nice, now we know how to forward the dataset correctly.
Next we need to identify the epoch and batch combination at which the model is checkpointed. With this information we can load the last checkpointed model, forward our dataset to that epoch and batch, and continue training. We will do this with a callback that tracks the current iteration number:
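class MyCustomCallback(tf.keras.callbacks.ModelCheckpoint):
    # ModelCheckpoint that exposes a global batch counter as {the_id} in the filepath
    def __init__(self, the_id=0, **kwargs):
        self.the_id = the_id
        super().__init__(**kwargs)
    def _save_model(self, *args, **kwargs):
        # Private hook whose signature varies across TF 2.x releases; `logs` is
        # passed by keyword, so inject the counter there for filepath formatting
        logs = dict(kwargs.get("logs") or {})
        logs["the_id"] = self.the_id
        kwargs["logs"] = logs
        super()._save_model(*args, **kwargs)
    def on_train_batch_end(self, batch, logs=None):
        # Count batches across epochs before ModelCheckpoint decides whether to save
        self.the_id += 1
        super().on_train_batch_end(batch, logs)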
checkpoint_filepath = 'checkpoint-{the_id}'
model_checkpoint_callback = MyCustomCallback(
filepath=checkpoint_filepath,
save_freq=2,
save_best_only=False)
model = get_model()
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), None)
model.fit(train_dataset, epochs=5, verbose=0, callbacks=[model_checkpoint_callback], workers=4, shuffle=False)
Output:
[7 3 6 10]
[11 0 1 2]
[8 14 9 13]
[12 5 4]
[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]
We are checkpointing every two batches. So let's assume training crashes and the last checkpoint is checkpoint-4. We can load this model, forward our dataset by 4 batches and continue training.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), 4)
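# Restore the model saved at global batch 4 (custom_objects is needed for the custom loss)
model = keras.models.load_model("checkpoint-4", custom_objects={"my_mse": my_mse})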
model.fit(train_dataset, epochs=2, verbose=0, workers=4, shuffle=False)
Output:
[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]
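To recover the counter (4 in the example above) after a crash, the checkpoint names written by MyCustomCallback can be parsed. This sketch assumes the 'checkpoint-{the_id}' template exactly as used above, with no file extension; latest_checkpoint_id is a hypothetical helper. Comparing parsed integers rather than names matters, because as strings 'checkpoint-4' sorts after 'checkpoint-10'.

```python
import re
from typing import Iterable, Optional

# Matches names produced by the 'checkpoint-{the_id}' filepath template.
_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint_id(names: Iterable[str]) -> Optional[int]:
    """Return the largest batch counter among checkpoint names, or None if there are none."""
    ids = [int(m.group(1)) for name in names if (m := _CKPT_RE.match(name))]
    return max(ids) if ids else None


if __name__ == "__main__":
    # In practice the names would come from e.g. os.listdir(".")
    print(latest_checkpoint_id(["checkpoint-2", "checkpoint-4", "checkpoint-10"]))
```

The returned number is what you would pass as n to forward, and the same name is what you would pass to load_model.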
I suppose you want to restore the shuffle order to avoid repeating some samples within this epoch.
According to the shuffle description, the buffer is first filled with shuffle_buffer_size elements and then pulls in one new element for every element it emits. So during an unfinished epoch your model only had access to the first (steps processed × batch size) + shuffle_buffer_size samples of your dataset.
So when you restore training, if you know how many steps were processed, you can skip that many samples plus another shuffle_buffer_size samples (skipping on the unshuffled, unbatched dataset), and training will continue on the following samples, which have not yet been seen in the current epoch.
Note that shuffle_buffer_size randomly chosen samples from the first part of the dataset will not be seen at all during this epoch. Since you say your epochs are very long, you probably have a lot of data, so losing shuffle_buffer_size samples should not be a problem.
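The skip count described above can be written as a small helper. It assumes the pipeline is source, then shuffle(shuffle_buffer_size), then batch(batch_size), with no repeat, and that the skip is applied to the unshuffled, unbatched source; resume_offsets is a hypothetical name.

```python
from typing import Tuple


def resume_offsets(steps_done: int, batch_size: int, buffer_size: int,
                   num_examples: int) -> Tuple[int, int]:
    """Return (elements_to_skip, elements_never_trained) for an interrupted epoch.

    After `steps_done` batches the shuffle buffer has pulled at most
    steps_done * batch_size + buffer_size elements from the unshuffled source.
    """
    trained = min(steps_done * batch_size, num_examples)
    pulled = min(trained + buffer_size, num_examples)
    # Elements pulled into the buffer but never emitted are skipped without training.
    return pulled, pulled - trained


# Possible use with tf.data (not executed here; names are placeholders):
#   skip, _ = resume_offsets(steps_done, BATCH, BUFFER, NUM_EXAMPLES)
#   rest = raw_dataset.skip(skip).shuffle(BUFFER).batch(BATCH)
#   model.fit(rest, epochs=1)
```

The second value is the number of samples lost for this epoch: shuffle_buffer_size in general, or fewer when the crash happens near the end of the data.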
So when saving a checkpoint, also save the step number. After loading the checkpoint, create a copy of the dataset with those samples skipped (using dataset.skip), use model.fit with this smaller dataset for one epoch to finish the current epoch, then continue training as usual.
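For saving the step number together with the checkpoint, one option is a small JSON file written next to it, for example from a callback's on_train_batch_end right after the model is saved. save_progress, load_progress and the JSON layout are hypothetical and use only the standard library. Writing to a temporary file in the same directory and then calling os.replace means a crash mid-write should not leave a truncated progress file in place.

```python
import json
import os
import tempfile


def save_progress(path: str, epoch: int, step_in_epoch: int) -> None:
    """Record the training position next to a checkpoint."""
    directory = os.path.dirname(os.path.abspath(path))
    # Write to a temp file in the same directory, then swap it in with one rename.
    fd, tmp = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump({"epoch": epoch, "step_in_epoch": step_in_epoch}, f)
    os.replace(tmp, path)


def load_progress(path: str) -> dict:
    """Return the saved position, or a fresh start if no progress file exists yet."""
    if not os.path.exists(path):
        return {"epoch": 0, "step_in_epoch": 0}
    with open(path) as f:
        return json.load(f)
```

After a restart, load_progress gives the epoch to resume and the step_in_epoch to feed into the skip calculation for the interrupted epoch.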