Reputation: 15673
I am practising TensorFlow with this tutorial. The evaluate
function depends on the training code to load the latest checkpoint:
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(encoder=encoder,
                           decoder=decoder,
                           optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

start_epoch = 0
if ckpt_manager.latest_checkpoint:
    # Checkpoint names end in "-<number>", so resume from that epoch
    start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
    # Restore the latest checkpoint found in checkpoint_path
    ckpt.restore(ckpt_manager.latest_checkpoint)

for epoch in range(start_epoch, EPOCHS):
    start = time.time()
    total_loss = 0

    for (batch, (img_tensor, target)) in enumerate(dataset):
        batch_loss, t_loss = train_step(img_tensor, target)
        total_loss += t_loss

        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(
                epoch + 1, batch, batch_loss.numpy() / int(target.shape[1])))

    # Store the epoch-end loss value to plot later
    loss_plot.append(total_loss / num_steps)
    ckpt_manager.save()
Without ckpt_manager.save(), the evaluation function does not work.
When we have already trained a model and the checkpoints are available in checkpoint_path, how should we load the model without training?
Upvotes: 3
Views: 2554
Reputation: 1455
You can use tf.train.latest_checkpoint to get the latest checkpoint file and then load it manually using ckpt.restore:
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(encoder=encoder,
                           decoder=decoder)
# Find the newest checkpoint in the directory and restore it; no training loop is needed
ckpt_path = tf.train.latest_checkpoint(checkpoint_path)
ckpt.restore(ckpt_path)
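
If it helps, here is a minimal, self-contained sketch of the same idea that you can run outside the tutorial. The helper name restore_latest is hypothetical, and the tf.Variable objects are only stand-ins for the tutorial's encoder and decoder (an assumption made so the snippet runs on its own). It also guards against the case where tf.train.latest_checkpoint returns None because the directory holds no checkpoint, and calls expect_partial() on the restore status since the saved optimizer state is not requested when you only evaluate.

```python
import tempfile

import tensorflow as tf


def restore_latest(ckpt, checkpoint_dir):
    """Restore the newest checkpoint in checkpoint_dir into ckpt.

    Returns the checkpoint path, or None if the directory has no checkpoint.
    (Hypothetical helper, not part of the tutorial.)
    """
    ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
    if ckpt_path is None:
        return None
    # expect_partial() silences warnings about saved values (for example
    # optimizer state) that this Checkpoint object does not ask for.
    ckpt.restore(ckpt_path).expect_partial()
    return ckpt_path


# --- Demo with stand-in objects (assumption: real code uses encoder/decoder) ---
checkpoint_dir = tempfile.mkdtemp()

# "Training" side: save one checkpoint with known values.
trained = tf.train.Checkpoint(encoder=tf.Variable([1.0, 2.0]),
                              decoder=tf.Variable([3.0]))
tf.train.CheckpointManager(trained, checkpoint_dir, max_to_keep=5).save()

# "Evaluation" side: fresh objects, restored without running any training.
fresh_encoder = tf.Variable([0.0, 0.0])
fresh_decoder = tf.Variable([0.0])
fresh = tf.train.Checkpoint(encoder=fresh_encoder, decoder=fresh_decoder)
restored_path = restore_latest(fresh, checkpoint_dir)

print("Restored from:", restored_path)
print("encoder:", fresh_encoder.numpy(), "decoder:", fresh_decoder.numpy())
```

In the tutorial you would build encoder and decoder exactly as in the training script, pass them to tf.train.Checkpoint, call the helper with "./checkpoints/train", and then run evaluate.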
Upvotes: 4