Redfox-Codder

Reputation: 183

Tensorflow checkpoints are being overwritten

I'm training a model (a generative adversarial network) on an input set using TensorFlow, and I would like to save the model's parameters every 50 epochs.

Let's say I want to train the model for 1000 epochs and save the model's parameters every 50 epochs, which should leave me with 20 different checkpoint files.

With a Session and a Saver object, I use the following code to do so:

if num_epoch % 50 == 0:
    saver.save(sess=sess, save_path='RGAN-1/sv/' + type_exp, global_step=num_epoch)

The problem is that the checkpoints are getting overwritten, and at the end of the experiment I only have the last few checkpoints, while I should have 20.

I have no idea why this is happening.

Upvotes: 0

Views: 621

Answers (1)

xdurch0

Reputation: 10474

tf.train.Saver has a max_to_keep argument that defaults to 5, so older checkpoint files are deleted as new ones are written. You can pass 0 (or None) to keep all checkpoints:

saver = tf.train.Saver(..., max_to_keep=0)

See the docs for a full argument list.
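To see why only the most recent files survive with the default setting, the Saver's rotation behaves like a fixed-size window over the save history. A minimal plain-Python illustration (not TensorFlow itself) of that rolling-window effect, assuming saves every 50 epochs over 1000 epochs:

```python
from collections import deque

# With max_to_keep=5 (the default), the Saver retains only the 5 most
# recent checkpoints, deleting older ones -- like a deque with maxlen=5.
kept = deque(maxlen=5)
for epoch in range(1, 1001):
    if epoch % 50 == 0:
        kept.append(f"model-{epoch}")  # stand-in for a checkpoint file

print(list(kept))
# Only the 5 newest "checkpoints" remain:
# ['model-800', 'model-850', 'model-900', 'model-950', 'model-1000']
```

With max_to_keep=0, nothing is ever rotated out, so all 20 checkpoints from the question would remain on disk.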

Upvotes: 2
