Reputation: 75
I have started exploring the TensorFlow library and am trying out the image classification example that uses MNIST data. I want to store the model in a file once training is over so that I can reuse it whenever required. I have checked this link, which explains how to save values from TensorFlow to a file and read them back. So far, I am able to save some variables from the script to a file using pickle, as suggested in the link. However, I cannot work out what needs to be saved to capture the current state of the model for later use. Could someone explain that part, with an example of storing a model and loading it again?
Upvotes: 0
Views: 2376
Reputation: 7150
Only Variables can be saved and restored. To reuse saved variables, you first need to rebuild the graph by creating the neural network and setting its parameters, such as the number of layers, the learning rate, and dropout. The only values restored from a checkpoint are the Variables defined during training. You can take a look at any example, for example this one.
To sum up, only Variables can and need to be saved and restored; the network configuration and placeholders cannot be.
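The idea of "rebuild the structure in code, then load only the values" can be illustrated without TensorFlow at all. The following is a plain-Python analogy using pickle, not TensorFlow's checkpoint format; the names build_params and predict, the sizes, and the file name are made up for illustration.

```python
import os
import pickle
import tempfile

# Configuration lives in code and is not part of the saved file.
NUM_INPUTS, NUM_OUTPUTS = 3, 2


def build_params():
    """Create freshly initialised parameters, keyed by name."""
    return {
        "w": [[0.0] * NUM_OUTPUTS for _ in range(NUM_INPUTS)],
        "b": [0.0] * NUM_OUTPUTS,
    }


def predict(params, x):
    """Linear layer (x @ w + b) written with plain lists."""
    return [
        sum(x[i] * params["w"][i][j] for i in range(NUM_INPUTS)) + params["b"][j]
        for j in range(NUM_OUTPUTS)
    ]


# "Training" side: only the name -> value map is written to disk.
trained = build_params()
trained["b"] = [0.5, -0.5]  # stand-in for values learned during training
path = os.path.join(tempfile.mkdtemp(), "params.pkl")  # hypothetical file
with open(path, "wb") as f:
    pickle.dump(trained, f)

# Later: rebuild the same structure first, then fill in the saved values by name.
restored = build_params()
with open(path, "rb") as f:
    saved = pickle.load(f)
for name in restored:
    restored[name] = saved[name]

output = predict(restored, [1.0, 2.0, 3.0])
```

If build_params were changed between saving and loading (for example, a renamed key), the lookup by name would fail, which mirrors why the graph has to be rebuilt the same way before restoring.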
Upvotes: 1
Reputation: 313
To save and restore variables in TensorFlow, two things are needed:
1) A list of the variables to be saved and restored
2) A tf.train.Saver
Generally, 1) is achieved by
# To save and restore all TF variables
all_vars = tf.global_variables()
or,
# To save and restore the specific tf variables using scope
all_vars = tf.global_variables()
model_vars = [k for k in all_vars if k.name.startswith("xxx")]
# "xxx" is the expected scope
Then, 2) is achieved by
saver = tf.train.Saver(vars_list)
# vars_list is the list of variables from above (all_vars or model_vars)
Finally, to save the variables (with a running tf.Session() named sess),
saver.save(sess, '/directory/to/checkpoint/file.ckpt')
and to restore them,
saver.restore(sess, '/directory/to/checkpoint/file.ckpt')
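The pieces above can be put together as a small end-to-end sketch. It assumes TensorFlow 1.14+ or 2.x accessed through tensorflow.compat.v1, so that the TF 1.x names used in this answer are available. The build_model helper, the "model" scope, the tiny linear layer, and the temporary checkpoint path are illustrative choices, not part of the original answer.

```python
import os
import tempfile

# Assumption: TF 1.14+ or 2.x, using the v1 compatibility API.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def build_model():
    """Recreate the graph; variable names must match those in the checkpoint."""
    with tf.variable_scope("model"):
        x = tf.placeholder(tf.float32, shape=[None, 784], name="x")
        w = tf.get_variable("w", shape=[784, 10], initializer=tf.zeros_initializer())
        b = tf.get_variable("b", shape=[10], initializer=tf.zeros_initializer())
        logits = tf.matmul(x, w) + b
    return x, logits, w, b


ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")  # hypothetical location

# Training side: build the graph, pick the variables by scope, save them.
tf.reset_default_graph()
x, logits, w, b = build_model()
model_vars = [v for v in tf.global_variables() if v.name.startswith("model/")]
saver = tf.train.Saver(model_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(b.assign(tf.ones([10])))  # stand-in for real training
    saver.save(sess, ckpt_path)

# Later (for example in another script): rebuild the same graph, then restore.
tf.reset_default_graph()
x, logits, w, b = build_model()
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, ckpt_path)  # restored variables need no initializer
    restored_b = sess.run(b)
```

Note that the second half never runs an initializer: restore assigns the saved values directly to the rebuilt variables, matching them by name.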
Upvotes: 2
Reputation: 1264
First, you should check out this other question.
TensorFlow has methods implemented for managing both saving and restoring checkpoints, specifically the tf.train.Saver class. Check out the official documentation here. Checkpoints basically store the values of your tensors (among other things) on disk.
Citing the documentation:
Checkpoints are binary files in a proprietary format which map variable names to tensor values. The best way to examine the contents of a checkpoint is to load it using a Saver.
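As a complement to loading with a Saver, the name-to-value mapping can also be read directly from the checkpoint file. The sketch below uses tf.train.list_variables and tf.train.load_variable, which is a different route from the one the quoted documentation recommends. It assumes TensorFlow 1.14+ or 2.x through tensorflow.compat.v1; the variable name "bias" and the temporary path are made up for the demonstration.

```python
import os
import tempfile

# Assumption: TF 1.14+ or 2.x, using the v1 compatibility API.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Write a tiny checkpoint so there is something to inspect.
tf.reset_default_graph()
v = tf.get_variable("bias", shape=[3], initializer=tf.ones_initializer())
ckpt_path = os.path.join(tempfile.mkdtemp(), "demo.ckpt")  # hypothetical location
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, ckpt_path)

# Inspect the checkpoint without rebuilding any graph:
# list_variables yields (name, shape) pairs, load_variable returns the stored array.
names = [name for name, shape in tf.train.list_variables(ckpt_path)]
bias_value = tf.train.load_variable(ckpt_path, "bias")
```

Listing the names this way is a quick check that the names in a rebuilt graph match what was actually saved.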
Upvotes: 0