ALeex

Reputation: 187

TensorFlow save/restore batch norm

I trained a model with batch norm in TensorFlow. I would like to save the model and restore it for further use. The batch norm is done by

def batch_norm(input, phase):
    return tf.layers.batch_normalization(input, training=phase)

where phase is True during training and False during testing.

It seems like simply calling

saver = tf.train.Saver()
saver.save(sess, savedir + "ckpt")

does not work well. When I restore the model, it first says it was restored successfully, but if I then run just one node in the graph, I get Attempting to use uninitialized value batch_normalization_585/beta. Is this related to not saving the model properly, or is it something else that I've missed?

Upvotes: 7

Views: 7397

Answers (2)

simo23

Reputation: 506

I also had the "Attempting to use uninitialized value batch_normalization_585/beta" error. Which variables end up in the checkpoint depends on how the saver is declared. With empty parentheses, like this:

         saver = tf.train.Saver()

the saver defaults to all saveable (global) variables that exist at the moment it is created. The moving mean and variance of batch normalization are not in tf.trainable_variables(), so a saver built from that list alone leaves them out of the checkpoint. To make the list explicit and include these variables in the saved ckpt, you can do:

         saver = tf.train.Saver(tf.global_variables())

This saves all the global variables (including, for example, optimizer slot variables), so the checkpoint can get large. Alternatively, identify the variables that hold the moving mean or variance and add them explicitly:

         saver = tf.train.Saver(tf.trainable_variables() + list_of_extra_variables)
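A minimal sketch of how list_of_extra_variables could be assembled, assuming the default variable names that tf.layers.batch_normalization creates (.../moving_mean and .../moving_variance). To keep it runnable without TensorFlow, the filter works on plain name strings; in TF 1.x you would apply the same test to v.op.name for each v in tf.global_variables() and keep the variable objects. The helper names and the example variable names are hypothetical.

```python
def is_bn_moving_stat(name):
    """True for batch-norm moving statistics, e.g. 'batch_normalization/moving_mean:0'."""
    # Take the last path component and drop an optional ':0' output suffix.
    leaf = name.split("/")[-1].split(":")[0]
    return leaf in ("moving_mean", "moving_variance")


def extra_bn_names(global_names, trainable_names):
    """Names that are global but not trainable and look like BN moving statistics."""
    trainable = set(trainable_names)
    return [n for n in global_names
            if n not in trainable and is_bn_moving_stat(n)]


# Hypothetical variable names, as they might appear in a small graph.
global_names = [
    "dense/kernel:0",
    "batch_normalization/gamma:0",
    "batch_normalization/beta:0",
    "batch_normalization/moving_mean:0",
    "batch_normalization/moving_variance:0",
    "dense/kernel/Adam:0",  # an optimizer slot: global, not trainable, not BN
]
trainable_names = [
    "dense/kernel:0",
    "batch_normalization/gamma:0",
    "batch_normalization/beta:0",
]
print(extra_bn_names(global_names, trainable_names))
# ['batch_normalization/moving_mean:0', 'batch_normalization/moving_variance:0']
```

With TensorFlow, the corresponding line would be something like extra = [v for v in tf.global_variables() if is_bn_moving_stat(v.op.name)], followed by tf.train.Saver(tf.trainable_variables() + extra). The optimizer slot is skipped, which is the point of this variant compared with saving every global variable.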

Upvotes: 8

javidcf

Reputation: 59731

Not sure if this needs to be explained, but just in case (and for other potential viewers).

Whenever you create an operation in TensorFlow, a new node is added to the graph. No two nodes in a graph can have the same name. You can define the name of any node you create, but if you don't give a name, TensorFlow picks one for you in a deterministic way (that is, not randomly, but always following the same sequence). If you add two numbers, the node will probably be called Add, but if you do another addition, since no two nodes can have the same name, it will be something like Add_1. Once a node is created in a graph, its name cannot be changed. Many functions create several subnodes in turn; for example, tf.layers.batch_normalization creates internal variables such as beta, gamma, moving_mean and moving_variance.
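A toy model of the naming scheme, in plain Python rather than TensorFlow: a counter per base name hands out the bare name first and then _1, _2, and so on. This is a simplified assumption about how the deduplication behaves (real TensorFlow also handles cases such as a user explicitly naming a node Add_1), but it shows why a name like batch_normalization_585 means the layer was built many times in the same graph, and why a fresh graph produces the same names again.

```python
class ToyGraph:
    """Toy stand-in for a TF graph's name registry (not TensorFlow itself)."""

    def __init__(self):
        self._counts = {}  # base name -> how many times it has been requested

    def unique_name(self, base):
        # First request gets the bare name, later requests get _1, _2, ...
        n = self._counts.get(base, 0)
        self._counts[base] = n + 1
        return base if n == 0 else "%s_%d" % (base, n)


g = ToyGraph()
print([g.unique_name("Add") for _ in range(3)])  # ['Add', 'Add_1', 'Add_2']

# Building the same layer 586 times in one graph ends at suffix _585.
bn_names = [g.unique_name("batch_normalization") for _ in range(586)]
print(bn_names[-1])  # batch_normalization_585

# A fresh graph starts counting again, so the same calls give the same names.
g2 = ToyGraph()
print(g2.unique_name("batch_normalization"))  # batch_normalization
```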

Saving and restoring works in the following way:

  1. You create a graph representing the model that you want. This graph contains the variables that will be saved by the saver.
  2. You initialize, train or do whatever you want with that graph, and the variables in the model get assigned some values.
  3. You call save on the saver to, well, save the values of the variables to a file.
  4. Now you recreate the model in a different graph (it can be a different Python session altogether or just another graph coexisting with the first one). The model must be created in exactly the same way the first one was.
  5. You call restore on the saver to retrieve the values of the variables.
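The five steps can be mimicked with a toy in plain Python, assuming a dict of name to value stands in for both the graph's variables and the checkpoint file. This is only an analogy, not how tf.train.Saver is implemented; it illustrates that restoring is a lookup by name, so the rebuilt model has to produce exactly the same names. The function names are hypothetical.

```python
def build_model():
    """Steps 1 and 4: create the same variables, with the same names, uninitialized."""
    return {"dense/kernel": None,
            "batch_normalization/beta": None,
            "batch_normalization/gamma": None}


def save(variables):
    """Step 3: write name -> value pairs (a dict copy stands in for the file)."""
    return dict(variables)


def restore(variables, checkpoint):
    """Step 5: copy values back by name; a name missing from the checkpoint is an error."""
    for name in variables:
        if name not in checkpoint:
            raise KeyError("%s not found in checkpoint" % name)
        variables[name] = checkpoint[name]


model1 = build_model()
# Step 2: "training" assigns some values.
for i, name in enumerate(sorted(model1)):
    model1[name] = float(i)
ckpt = save(model1)

model2 = build_model()  # rebuilt the same way, so every name matches
restore(model2, ckpt)
print(model2 == model1)  # True
```

If the second build produced, say, batch_normalization_1/beta instead, the lookup would fail, which is roughly what a "not found in checkpoint" error from a real restore reflects.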

In order for this to work, the names of the variables in the first and the second graph must be exactly the same.

In your example, TensorFlow is complaining about the variable batch_normalization_585/beta. It seems that you have called tf.layers.batch_normalization nearly 600 times in the same graph, so you have that many beta variables hanging around. I doubt that you actually need that many, so I guess you are just experimenting with the API and ended up with that many copies.
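One way to check this is to compare the variable names stored in the checkpoint with the names in the graph you restore into. In recent TF 1.x versions, tf.train.list_variables(path) returns (name, shape) pairs for a checkpoint, and v.op.name gives the name of a graph variable; the helper below only assumes you already have those two lists of strings. The helper and the example names are hypothetical.

```python
def compare_names(ckpt_names, graph_names):
    """Return (only_in_graph, only_in_ckpt), both sorted."""
    ckpt, graph = set(ckpt_names), set(graph_names)
    return sorted(graph - ckpt), sorted(ckpt - graph)


# Hypothetical: the checkpoint holds one batch norm layer, but the graph was
# extended by a second call to tf.layers.batch_normalization.
ckpt_names = ["batch_normalization/beta", "batch_normalization/gamma"]
graph_names = ckpt_names + ["batch_normalization_1/beta",
                            "batch_normalization_1/gamma"]

only_graph, only_ckpt = compare_names(ckpt_names, graph_names)
print(only_graph)  # ['batch_normalization_1/beta', 'batch_normalization_1/gamma']
print(only_ckpt)   # []
```

Anything listed in only_graph cannot get its value from that checkpoint and will stay uninitialized unless it is initialized some other way.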

Here's a draft of something that should work:

import tensorflow as tf

def make_model():
    input = tf.placeholder(...)
    phase = tf.placeholder(...)
    input_norm = tf.layers.batch_normalization(input, training=phase)
    # Do some operations with input_norm
    output = ...
    # Created last, so it covers every variable defined above
    saver = tf.train.Saver()
    return input, output, phase, saver

# We work with one graph first
g1 = tf.Graph()
with g1.as_default():
    input, output, phase, saver = make_model()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Do your training or whatever...
        saver.save(sess, savedir + "ckpt")

# We work with a second different graph now
g2 = tf.Graph()
with g2.as_default():
    input, output, phase, saver = make_model()
    with tf.Session() as sess:
        saver.restore(sess, savedir + "ckpt")
        # Continue using your model...

Again, the typical case is not to have two graphs side by side, but rather have one graph and then recreate it in another Python session later, but in the end both things are the same. The important part is that the model is created in the same way (and therefore with the same node names) in both cases.
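A related detail, as I understand tf.train.Saver in TF 1.x: a saver built with no arguments collects the variables that exist when it is constructed, so variables created after it are neither saved nor restored. That is consistent with make_model creating the saver last. The toy below (plain Python, not TensorFlow; the class names are hypothetical) shows how a restore can "succeed" and still leave a later variable uninitialized.

```python
class ToyGraph:
    """Toy graph: name -> value, where None means uninitialized."""

    def __init__(self):
        self.variables = {}

    def variable(self, name):
        self.variables[name] = None
        return name


class ToySaver:
    """Toy saver: snapshots the variable names that exist when it is created."""

    def __init__(self, graph):
        self.graph = graph
        self.names = list(graph.variables)  # later variables are not included

    def save(self):
        return {n: self.graph.variables[n] for n in self.names}

    def restore(self, checkpoint):
        for n in self.names:
            self.graph.variables[n] = checkpoint[n]


def uninitialized(graph):
    return sorted(n for n, v in graph.variables.items() if v is None)


# Original graph: layer first, saver last.
g1 = ToyGraph()
g1.variable("batch_normalization/beta")
saver1 = ToySaver(g1)
g1.variables["batch_normalization/beta"] = 0.5  # pretend training
ckpt = saver1.save()

# New graph where a second layer is created after the saver.
g2 = ToyGraph()
g2.variable("batch_normalization/beta")
saver2 = ToySaver(g2)
g2.variable("batch_normalization_1/beta")
saver2.restore(ckpt)      # every name the saver knows is found
print(uninitialized(g2))  # ['batch_normalization_1/beta']
```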

Upvotes: 4
