SantoshGupta7

Reputation: 6197

Keras training with shuffled tf.data: if training is interrupted, how to continue training at last data iteration/order of last saved checkpoint

I am training with keras model.fit, and the data comes from tf.records, loaded into a tf.data object, which uses .shuffle to shuffle the data. I am also using callbacks.ModelCheckpoint to save the model every x number of steps/batches.

Sometimes my cloud instance disconnects or crashes before an epoch is finished, but the model at step y is saved to my drive.

I would like to finish training over the remaining data in that epoch (I have very long epochs) before starting another epoch, so that each data example is trained over once per epoch.

Is there a way to get the original order of the data, and the place within the data where the model was last saved?

What I have found so far

It looks like you can set a specific order in .shuffle by setting the seed. However, shuffling only occurs in the buffer, so I am not 100% sure if setting the seed will perfectly reproduce the order. Also, I am not sure how that will work with reshuffle_each_iteration. Is a different seed used after each epoch? If so, I guess a workaround is to train only 1 epoch at a time, with a specified seed for each epoch.
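
A rough sketch of that per-epoch-seed workaround (the helper name, buffer size, and batch size here are just placeholders for my actual pipeline, and I haven't tested this):

import tensorflow as tf

# Rough sketch (untested): re-create the dataset each epoch with a known,
# per-epoch seed so that a given epoch's order can be reproduced after a crash.
def make_epoch_dataset(record_files, epoch_seed):
    ds = tf.data.TFRecordDataset(record_files)
    ds = ds.shuffle(buffer_size=10000, seed=epoch_seed,
                    reshuffle_each_iteration=False)
    return ds.batch(32)

# for epoch in range(num_epochs):
#     model.fit(make_epoch_dataset(files, epoch_seed=epoch), epochs=1)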

Even if I do get a replica of the training order, I'm not sure how to find where in that order the model was last saved, and then start training from that point. One idea I have to get to that point in the order is to iterate through the dataset manually until I reach it, although I'm not sure if model.fit() would continue from there or start all over.

As for the step/batch number at which the model was last saved, I could probably log this somewhere.
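
Just as a sketch of that (the callback name and log file are made up, and I haven't tested it), something like this could write out the global batch index on every batch, so the last value written roughly matches the last checkpoint:

import tensorflow as tf

# Rough sketch (untested): record the global batch index so training can
# later be resumed from roughly the same position in the data.
class StepLogger(tf.keras.callbacks.Callback):
    def __init__(self, log_path='last_step.txt'):
        super().__init__()
        self.log_path = log_path
        self.global_step = 0

    def on_train_batch_end(self, batch, logs=None):
        self.global_step += 1
        with open(self.log_path, 'w') as f:
            f.write(str(self.global_step))

# model.fit(dataset, callbacks=[StepLogger(), checkpoint_callback])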

These solutions seem like rough workarounds, and I am wondering if there are features in Keras that I may be overlooking which could help with this.

Upvotes: 2

Views: 1691

Answers (2)

mujjiga

Reputation: 16906

There seems to be no built-in Keras way to do this, but please correct me if I am wrong.

My Approach

Dataset.shuffle internally uses the initial seed value to generate the seeds used for reshuffling during iterations when reshuffle_each_iteration=True. So to re-create the same order for a particular epoch and continue training at a particular batch, we have to re-create the Dataset with the same seed and move the dataset iterator forward to that same epoch and batch.

Debugging

For debugging, and to make sure the epochs and batches are generated in the same order, we need a way to print how the data points are picked up in each epoch-batch. This is tricky in Keras, so for debugging purposes I will use a regression problem with sequential numbers as the ground truth. Then I can use a custom loss that prints the ground truth, so I can verify that the order is correct.

Model and Data

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
import numpy as np


# Data
x_train = np.random.randn(15, 10).astype("float32")
y_train = np.arange(15).astype("float32")

# Custom MSE loss just to track the order in which data is picked up
def my_mse(y_true, y_pred):
    tf.print(tf.keras.backend.flatten(y_true))
    loss = K.square(y_pred - y_true)
    loss = K.sum(loss, axis=1)
    return loss

# Model
def get_model():
    inputs = keras.Input(shape=(10,))
    outputs = layers.Dense(1, activation="linear")(inputs)
    model = keras.Model(inputs=inputs, outputs=outputs)
    
    model.compile(
        optimizer="rmsprop",
        loss=my_mse,
    )
    return model

Dataset

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(8)

epochs = 2

print ("Runs 1")
for e in range(epochs):
  for i, (x, y) in enumerate(train_dataset):
    print (e, i, y)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(8)
print ("Runs 2")
for e in range(epochs):
  for i, (x, y) in enumerate(train_dataset):
    print (e, i, y)

Output:

Runs 1
0 tf.Tensor([1. 3. 5. 7. 4. 0. 8. 2.], shape=(8,), dtype=float32)
1 tf.Tensor([ 6. 11. 10. 14.  9. 12. 13.], shape=(7,), dtype=float32)
2 tf.Tensor([4. 2. 5. 8. 1. 9. 7. 3.], shape=(8,), dtype=float32)
3 tf.Tensor([13. 10.  0. 14.  6. 11. 12.], shape=(7,), dtype=float32)
4 tf.Tensor([ 0.  1.  5.  6.  9.  3.  7. 14.], shape=(8,), dtype=float32)
5 tf.Tensor([13.  8.  4. 10.  2. 12. 11.], shape=(7,), dtype=float32)
Runs 2
0 tf.Tensor([1. 3. 5. 7. 4. 0. 8. 2.], shape=(8,), dtype=float32)
1 tf.Tensor([ 6. 11. 10. 14.  9. 12. 13.], shape=(7,), dtype=float32)
2 tf.Tensor([4. 2. 5. 8. 1. 9. 7. 3.], shape=(8,), dtype=float32)
3 tf.Tensor([13. 10.  0. 14.  6. 11. 12.], shape=(7,), dtype=float32)
4 tf.Tensor([ 0.  1.  5.  6.  9.  3.  7. 14.], shape=(8,), dtype=float32)
5 tf.Tensor([13.  8.  4. 10.  2. 12. 11.], shape=(7,), dtype=float32)

Yes, with the seed the order is reproduced.

Now let's write a method to forward the dataset to a certain epoch and batch combination.

def forward(dataset, n=None):
  # Consume the first n batches so that the dataset's shuffle state advances
  # to the epoch/batch combination we want to resume from.
  if not n:
    return dataset

  i = 0
  while True:
    for _ in dataset:
      i += 1
      if i == n:
        return dataset

Test cases:

Let's run it normally and observe the order.

Data from the beginning

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), None)

model = get_model()
model.fit(train_dataset, epochs=3, verbose=0, workers=4, shuffle=False)

Output:

[7 3 6 10]
[11 0 1 2]
[8 14 9 13]
[12 5 4]
[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]

Data from the nth state of Dataset

Let's forward our dataset to the 4th iteration and run the training.

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), 4)

model = get_model()
model.fit(train_dataset, epochs=3, verbose=0, workers=4, shuffle=False)

Output:

[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]

Nice, now we know how to forward the dataset correctly. Let's now write a callback to track the current iteration number:

Custom callback to track the iteration (epoch-batch combination)

Now we need to identify the epoch and batch combination at which the model was checkpointed. If we have this information, we can load the last checkpointed model, forward our dataset to its batch and epoch combination, and continue the training. We will do this using callbacks.

# ModelCheckpoint already derives from Callback, so subclassing it is enough
class MyCustomCallback(tf.keras.callbacks.ModelCheckpoint):
    def __init__(self, the_id=0, **kwargs):
      # the_id is a running batch counter that gets embedded in the checkpoint
      # file name, so we know where to resume the dataset from
      self.the_id = the_id
      self.epoch = 0
      super().__init__(**kwargs)

    def _save_model(self, epoch, logs):
      # expose the current iteration number to the filepath template below
      logs['the_id'] = self.the_id
      super()._save_model(epoch, logs)

    def on_batch_end(self, batch, logs={}):
      self.the_id += 1
      super().on_batch_end(batch, logs)

checkpoint_filepath = 'checkpoint-{the_id}'
model_checkpoint_callback = MyCustomCallback(
    filepath=checkpoint_filepath,
    save_freq=2,            # checkpoint every 2 batches
    save_best_only=False)

model = get_model()

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), None)

model.fit(train_dataset, epochs=5, verbose=0, callbacks=[model_checkpoint_callback], workers=4, shuffle=False)

Output:

[7 3 6 10]
[11 0 1 2]
[8 14 9 13]
[12 5 4]
[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]

We are checkpointing every two batches. So let's assume training crashes and the last checkpoint is checkpoint-4. We can load this model, forward our dataset to 4, and continue training.

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), 4)

# To actually resume, load the last checkpoint instead of building a fresh model,
# e.g. model = keras.models.load_model('checkpoint-4', custom_objects={'my_mse': my_mse});
# a fresh model is used here only to demonstrate the data order.
model = get_model()
model.fit(train_dataset, epochs=2, verbose=0, workers=4, shuffle=False)

Output:

[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]

Upvotes: 2

Dmitrii Rashchenko

Reputation: 198

I suppose you want to restore the shuffle order to avoid repeating some samples within this epoch.

According to the shuffle description, during an unfinished epoch your model only had access to the first current_step_number + shuffle_buffer_size samples from your dataset.

So when you restore your training, if you know how many steps were processed, you can just skip those steps plus shuffle_buffer_size samples, and your training will continue on the following samples, which have not yet been observed within the current epoch.

Note that some random shuffle_buffer_size samples from the first part of the dataset will not be observed at all during this epoch. As you say, your epoch is very long, so you probably have a lot of data, and losing shuffle_buffer_size samples should not be a problem for you.

So when saving a checkpoint, also save the step number; then after loading the checkpoint, create a dataset copy with those steps skipped (using dataset.skip), run model.fit with this smaller dataset for one epoch (to finish the current epoch), and then continue your training in the usual way.
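
A minimal sketch of that idea (the step count, batch size, and buffer size below are placeholders; dataset.skip counts individual samples, so the processed batches are converted to samples first):

import tensorflow as tf

# Rough sketch: resume mid-epoch by skipping the samples already seen.
steps_done = 1000            # batches processed, read back from your own log
batch_size = 32
shuffle_buffer_size = 10000

dataset = tf.data.TFRecordDataset(['train.tfrecord'])
# Skip the samples the model may already have seen: the processed batches
# plus (approximately) the contents of the shuffle buffer.
resume_dataset = dataset.skip(steps_done * batch_size + shuffle_buffer_size)
resume_dataset = resume_dataset.batch(batch_size)

# model.fit(resume_dataset, epochs=1)   # finish the interrupted epoch
# ...then continue training as usual on the full, shuffled dataset.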

Upvotes: 0
