Reputation: 1835
I have code for training a CNN using tf.train.MonitoredTrainingSession.
When I create a new tf.train.MonitoredTrainingSession, I can pass the checkpoint directory as an input parameter to the session, and it will automatically restore the latest saved checkpoint it can find. I can also set up hooks to train until a given step. For example, if the checkpoint's step is 150,000 and I would like to train until step 200,000, I set last_step to 200,000.
The above process works perfectly as long as the latest checkpoint was saved with a tf.train.MonitoredTrainingSession. However, if I try to restore a checkpoint that was saved with a plain tf.Session, it breaks: it cannot find some keys in the graph.
The training is done with this:
with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.retrain_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)
If the folder given as checkpoint_dir contains no checkpoints, training starts from scratch. If it contains a checkpoint saved by a previous training session, the latest checkpoint is restored and training continues from there.
Now, I am restoring the latest checkpoint, modifying some variables, and saving them:
saver = tf.train.Saver(variables_to_restore)
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)

with tf.Session() as sess:
    if ckpt and ckpt.model_checkpoint_path:
        # Restores from checkpoint
        saver.restore(sess, ckpt.model_checkpoint_path)
        print(ckpt.model_checkpoint_path)
        restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps
    else:
        print('No checkpoint file found')
        return

    prune_convs(sess)
    saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step)
As you can see, just before saver.save(...) I am pruning all convolutional layers in the network; there is no need to describe how and why that is done. The point is that the network is in fact modified, and then I save it to a checkpoint.
Now, if I run my test on the saved, modified network, it works just fine. However, when I try to run tf.train.MonitoredTrainingSession on the checkpoint that was saved, it says:
Key conv1/weight_loss/avg not found in checkpoint
I have also noticed that the checkpoint saved with tf.Session is half the size of the checkpoint saved with tf.train.MonitoredTrainingSession.
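For what it's worth, one way to see the difference is to list the keys stored in each checkpoint with tf.train.list_variables (directories as in the code above; this is just a way to inspect them, not part of the training code):

import tensorflow as tf

# Print every variable name and shape stored in the latest checkpoint
# of each directory.
for name, shape in tf.train.list_variables(FLAGS.train_dir):
    print('train_dir:', name, shape)

for name, shape in tf.train.list_variables(FLAGS.retrain_dir):
    print('retrain_dir:', name, shape)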
I know I'm doing something wrong; any suggestions on how to make this work?
Upvotes: 1
Views: 1285
Reputation: 1835
I figured it out. Apparently, the tf.train.Saver I had created was not restoring all of the variables from the checkpoint; I tried restoring and then saving immediately, and the resulting checkpoint was half the size.
I used tf.train.list_variables to get all of the variables from the latest checkpoint, converted them into tf.Variable objects, and created a dict from them. Then I passed the dict to tf.train.Saver and it restored all of my variables.
The next thing was to initialize all of the variables and then modify the weights.
Now it is working.
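For reference, a minimal sketch of that approach (using the FLAGS.* names from the question; the pruning step is elided and this is an illustration rather than my exact code):

import tensorflow as tf

ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
ckpt_path = ckpt.model_checkpoint_path
restored_step = ckpt_path.split('/')[-1].split('-')[-1]

# Build a tf.Variable for *every* entry in the checkpoint and map the
# checkpoint key to it, so the Saver covers all variables, not a subset.
restore_dict = {}
for name, _ in tf.train.list_variables(ckpt_path):
    value = tf.train.load_variable(ckpt_path, name)
    restore_dict[name] = tf.Variable(value, name=name)

saver = tf.train.Saver(restore_dict)

with tf.Session() as sess:
    # The initializers already hold the checkpoint values, so initializing
    # all variables both restores the checkpoint state and initializes the
    # graph, as described above.
    sess.run(tf.global_variables_initializer())

    # ... modify the weights here (the pruning step in my case) ...

    saver.save(sess, FLAGS.retrain_dir + "model.ckpt-" + restored_step)

Because the dict keys are the original checkpoint names, the resulting checkpoint contains every variable the MonitoredTrainingSession expects, including the moving-average entries that were missing before.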
Upvotes: 1