Googlebot

How to load the last checkpoint in TensorFlow?

I am practising with TensorFlow using this tutorial. The evaluate function relies on the training code below to load the latest checkpoint:

import time
import tensorflow as tf

# encoder, decoder, optimizer, train_step, dataset, EPOCHS, num_steps and
# loss_plot are defined earlier in the tutorial.
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(encoder=encoder,
                           decoder=decoder,
                           optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Resume from the newest checkpoint, if one exists
start_epoch = 0
if ckpt_manager.latest_checkpoint:
    start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
    ckpt.restore(ckpt_manager.latest_checkpoint)

for epoch in range(start_epoch, EPOCHS):
    start = time.time()
    total_loss = 0

    for (batch, (img_tensor, target)) in enumerate(dataset):
        batch_loss, t_loss = train_step(img_tensor, target)
        total_loss += t_loss

        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(
                epoch + 1, batch, batch_loss.numpy() / int(target.shape[1])))
    loss_plot.append(total_loss / num_steps)

    # One checkpoint per epoch: ckpt-1, ckpt-2, ...
    ckpt_manager.save()
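The resume step works because CheckpointManager.save() numbers each file with the checkpoint's save counter (default prefix ckpt, so ckpt-1, ckpt-2, ...), and the loop above saves once per epoch, so the trailing number is the count of finished epochs. A minimal pure-Python sketch of that parsing, where start_epoch_from_checkpoint is a hypothetical helper name and not part of TensorFlow:

```python
from typing import Optional


def start_epoch_from_checkpoint(latest_checkpoint: Optional[str]) -> int:
    """Hypothetical helper mirroring the resume logic in the question.

    Assumes checkpoint paths end in '<prefix>-<number>' and that exactly one
    checkpoint is saved per epoch, so <number> == completed epochs.
    """
    if not latest_checkpoint:  # None or "": no checkpoint yet, start fresh
        return 0
    # Only the last '-' matters, so dashes in the directory name are harmless.
    return int(latest_checkpoint.rsplit("-", 1)[-1])


print(start_epoch_from_checkpoint(None))                          # 0
print(start_epoch_from_checkpoint("./checkpoints/train/ckpt-5"))  # 5
```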

Without ckpt_manager.save(), the evaluate function does not work.

If we have already trained a model and the checkpoints are available in checkpoint_path, how should we load the model without training?

Answers (1)

Abhinav Goyal

You can use tf.train.latest_checkpoint to get the path of the newest checkpoint in the directory and then load it with ckpt.restore, without running any training:

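import tensorflow as tf

checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(encoder=encoder,
                           decoder=decoder,
                           optimizer=optimizer)

# Newest checkpoint prefix in the directory (None if none exist)
ckpt_path = tf.train.latest_checkpoint(checkpoint_path)
# expect_partial() silences warnings about unused values such as optimizer slots
ckpt.restore(ckpt_path).expect_partial()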
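To check this restore path end to end without the tutorial's models, here is a minimal sketch that uses a single tf.Variable as a stand-in for the encoder/decoder (an assumption for illustration only). It saves with a CheckpointManager into a temporary directory, then restores into a freshly built Checkpoint via tf.train.latest_checkpoint, as the answer does:

```python
import tempfile

import tensorflow as tf

# Stand-in for the trained encoder/decoder: one trainable variable.
weight = tf.Variable(1.0)
ckpt = tf.train.Checkpoint(weight=weight)
manager = tf.train.CheckpointManager(ckpt, tempfile.mkdtemp(), max_to_keep=5)

# "Training": change the value, then write <dir>/ckpt-1.
weight.assign(42.0)
manager.save()

# "Evaluation": rebuild the same object structure with a different value...
restored_weight = tf.Variable(0.0)
restored_ckpt = tf.train.Checkpoint(weight=restored_weight)

# ...and load the newest checkpoint in the directory without any training.
latest = tf.train.latest_checkpoint(manager.directory)
restored_ckpt.restore(latest).expect_partial()

print(latest)                   # path ending in ckpt-1
print(restored_weight.numpy())  # 42.0
```

Values are matched by attribute name (weight here; encoder, decoder and optimizer in the tutorial), so the evaluation script has to build its Checkpoint with the same keyword names that were used when saving.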
