K. Linke

Reputation: 1

Tensorflow: Save / Restore Variables Loading Incorrectly

I am having trouble saving / restoring a TensorFlow model that I trained. The visible symptom is that the restored model performs poorly because it uses random weights instead of the trained ones. Others have encountered this problem, but the root cause for me seems to be different, perhaps duplicated variable names.

I have verified that the weights in the checkpoint file match the ones used in the trained model, and that the weights after restoration are completely different. I have also seen duplicate variables ending in _1 before saving. Even more unusually, there appear to be duplicate variables with the same names after restoring.
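For reference, this is roughly how I verified the checkpoint contents against the live session (a condensed sketch; save_path is the prefix returned by saver.save in the training code below):

    # Sketch of the check (TF 1.x): read a weight directly from the checkpoint
    # and compare it with the value currently held by the session.
    reader = tf.train.NewCheckpointReader(save_path)   # save_path comes from saver.save below
    ckpt_w = reader.get_tensor("Wconv1")               # value stored on disk
    live_w = sess.run("Wconv1:0")                      # value in the running session
    print(np.allclose(ckpt_w, live_w))                 # True right after saving, False after the bad restore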

I have been careful to run tf.reset_default_graph() before first building the graph and again before restoring it. One potential problem is that I run global_variables_initializer() before restoring the variables; if I do not, I get an uninitialized variable error when I run the restored model.
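In condensed form, the order I use when restoring looks like this (a sketch of the full restore code further down):

    # Condensed sketch of my restore order (TF 1.x); the full version is at the end of the post.
    tf.reset_default_graph()
    with tf.Session() as sess:
        X = tf.placeholder(tf.float32, [None, 64, 64, 64])
        y = tf.placeholder(tf.float32, [None, 64, 64, 64])
        is_training = tf.placeholder(tf.bool)
        y_out = model_ten_layer_unet(X, y, is_training)       # rebuild the graph by hand
        sess.run(tf.global_variables_initializer())           # without this I get the uninitialized variable error
        new_saver = tf.train.import_meta_graph(save_path + '.meta')  # also imports the saved graph definition
        new_saver.restore(sess, save_path)                     # expected to overwrite the random initial values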

I am using the latest version of TensorFlow with GPU support. Here is the code, split into building, training/saving, and restoring:

#build model
def model_ten_layer_unet(X,y,is_training):
    # define our weights (e.g. init_two_layer_convnet)
    M = 4

    # setup variables
    Wconv1 = tf.get_variable("Wconv1", shape=[3, 3, 3, 1, M*32]) #SAME: Never change output size
    bconv1 = tf.get_variable("bconv1", shape=[M*32])
    Wconv2 = tf.get_variable("Wconv2", shape=[3, 3, 3, M*32, M*32])
    bconv2 = tf.get_variable("bconv2", shape=[M*32])
    ...
    #Wconv12= tf.get_variable("Wconv12", shape=[3, 3, 3, 1, M*32]) #SAME: Final layer combines stacked outputs
    #bconv12= tf.get_variable("bconv12", shape=[1])

    # define our graph (e.g. two_layer_convnet)
    d1 = tf.nn.max_pool3d(tf.expand_dims(X,-1), ksize=[1,2,2,2,1], strides=[1,2,2,2,1], padding='VALID')
    print('d1 shape: ', d1.shape)
    ...
    a12 = tf.layers.conv3d_transpose(c10, filters=1, kernel_size=[3,3,3], strides=[2,2,2], padding='SAME')
    print('a12 shape: ', a12.shape)
    y_out = tf.add(tf.squeeze(a12, -1),X)
    print('y_out shape: ', y_out.shape)
    return y_out


def run_model(session, predict, loss_val, Xd, yd, mean_image,
              epochs=1, batch_size=64, print_every=100,
              training=None, plot_losses=False):

    # have tensorflow compute accuracy
    mse = tf.reduce_mean(tf.square(tf.subtract(predict, y)))
    accuracy = tf.sqrt(mse)/tf.reduce_mean(tf.cast(mean_image,tf.float32))

    # shuffle indices
    train_indicies = np.arange(Xd.shape[0])
    np.random.shuffle(train_indicies)

    training_now = training is not None

    # set up the ops we want to compute (and, when training, the train step)
    # if a training op is given, it replaces the mse slot in the list below
    variables = [mean_loss, accuracy, mse]
    if training_now:
        variables[-1] = training

    # counter 
    iter_cnt = 0
    for e in range(epochs):
        # keep track of losses and accuracy
        losses = []
        accuracies = []
        # make sure we iterate over the dataset once
        for i in range(int(math.ceil(Xd.shape[0]/batch_size))):
            # generate indices for the batch
            start_idx = (i*batch_size)%Xd.shape[0]
            idx = train_indicies[start_idx:start_idx+batch_size]

            # create a feed dictionary for this batch
            feed_dict = {X: Xd[idx,:],
                         y: yd[idx],
                         is_training: training_now }
            # get batch size
            actual_batch_size = yd[idx].shape[0]

            # have tensorflow compute loss and correct predictions
            # and (if given) perform a training step
            loss, accuracy, _ = session.run(variables,feed_dict=feed_dict)

            # aggregate performance stats
            losses.append(loss)
            accuracies.append(accuracy)

            # print every now and then
            if training_now and (iter_cnt % print_every) == 0:
                print("Iteration {0}: with minibatch training loss = {1:.3g} and accuracy of {2:.2g}"\
                      .format(iter_cnt,loss,accuracy))
            iter_cnt += 1
        total_accuracy = np.mean(accuracies)
        total_loss = np.mean(losses)
        ...
    return total_loss, total_accuracy, y_out


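#set up graph, loss and optimizer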
    tf.reset_default_graph()

    X = tf.placeholder(tf.float32, [None, 64, 64, 64])
    y = tf.placeholder(tf.float32, [None, 64, 64, 64])
    is_training = tf.placeholder(tf.bool)

    y_out = model_ten_layer_unet(X,y,is_training)
    total_loss = tf.square(tf.subtract(y_out, y))
    mean_loss = tf.reduce_mean(total_loss)

    global_step = tf.Variable(0, trainable=False)
    boundaries = [5000,10000,15000]
    values = [4e-4,2e-4,5e-5,1e-5]
    learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

    optimizer = tf.train.AdamOptimizer(learning_rate) # select optimizer and set learning rate

    # enable saving and loading models
    saver = tf.train.Saver()

    #set up path for saving models
    home = os.getenv("HOME")
    work_dir = 'assignment2'
    model_dir = 'models'
    model_name = 'unet1'
    model_path = os.path.join(home,work_dir,model_dir)
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    model_pathname = os.path.join(model_path,model_name)

    # batch normalization in tensorflow requires this extra dependency
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(extra_update_ops):
        train_step = optimizer.minimize(mean_loss,global_step=global_step)


#train model
    with tf.Session() as sess:
        #with tf.device("/gpu:0"): #"/cpu:0" or "/gpu:0" 
            sess.run(tf.global_variables_initializer())
            #print(sess.run('Wconv1:0'))

            for e in range(1):
                print('Training epoch ', e+1)
                run_model(sess,y_out,mean_loss,X_train,y_train,mean_image,1,32,1,train_step,False)
                print('Validation epoch ', e+1)
                _,_,y_hat = run_model(sess,y_out,mean_loss,X_val,y_val,mean_image,1,32,1)

            print('Global Variables')
            for v in tf.global_variables():
                print(v.name)
            print('Local Variables')
            for v in tf.local_variables():
                print(v.name)

            saver.export_meta_graph(model_pathname, clear_devices=True)
            save_path = saver.save(sess, model_pathname, global_step=global_step)
            print("Model saved in path: %s" % save_path)

    #Output
    Training epoch  1
    Iteration 0: with minibatch training loss = 7.95e+03 and accuracy of 0.2
    ...
    Iteration 27: with minibatch training loss = 7.51e+03 and accuracy of 0.2
    Global Variables
    Wconv1:0
    bconv1:0
    Wconv2:0
    bconv2:0
    Wconv3:0
    bconv3:0
    conv3d_transpose/kernel:0
    conv3d_transpose/bias:0
    conv3d_transpose_1/kernel:0
    conv3d_transpose_1/bias:0
    Variable:0
    beta1_power:0
    beta2_power:0
    Wconv1/Adam:0
    Wconv1/Adam_1:0 **Duplicated variables with _1 before saving**
    bconv1/Adam:0
    bconv1/Adam_1:0
    Wconv2/Adam:0
    Wconv2/Adam_1:0
    bconv2/Adam:0
    bconv2/Adam_1:0
    Wconv3/Adam:0
    Wconv3/Adam_1:0
    bconv3/Adam:0
    bconv3/Adam_1:0
    conv3d_transpose/kernel/Adam:0
    conv3d_transpose/kernel/Adam_1:0
    ...
    conv3d_transpose_1/bias/Adam:0
    conv3d_transpose_1/bias/Adam_1:0



#restore model    
    tf.reset_default_graph()

    # Later, launch the model, use the saver to restore variables from disk, and
    # do some work with the model.
    with tf.Session() as sess:

        X = tf.placeholder(tf.float32, [None, 64, 64, 64])
        y = tf.placeholder(tf.float32, [None, 64, 64, 64])
        is_training = tf.placeholder(tf.bool)
        y_out = model_ten_layer_unet(X,y,is_training)
        total_loss = tf.square(tf.subtract(y_out, y))
        mean_loss = tf.reduce_mean(total_loss)

        sess.run(tf.global_variables_initializer())

        new_saver = tf.train.import_meta_graph(save_path+'.meta')

        # Restore variables from disk.
        new_saver.restore(sess, save_path)
        print("Model restored.")

        # Load output images
        print('Global Variables')
        for v in tf.global_variables():
            print(v.name)
        print('Local Variables')
        for v in tf.local_variables():
            print(v.name)
        _,_,y_hat = run_model(sess,y_out,mean_loss,X_val,y_val,mean_image,1,32,1)
        y_hat = y_hat.eval(feed_dict = {X:X_val[0:9,],is_training:False})


        #Output
        d1 shape:  (?, 32, 32, 32, 1)
        ...
        y_out shape:  (?, 64, 64, 64)
        Global Variables
        Wconv1:0
        bconv1:0
        Wconv2:0
        bconv2:0
        Wconv3:0
        bconv3:0
        conv3d_transpose/kernel:0
        conv3d_transpose/bias:0
        conv3d_transpose_1/kernel:0
        conv3d_transpose_1/bias:0
        Wconv1:0 **Repeated variables with the same name**
        bconv1:0
        Wconv2:0
        bconv2:0
        Wconv3:0
        bconv3:0
        conv3d_transpose/kernel:0
        conv3d_transpose/bias:0
        conv3d_transpose_1/kernel:0
        conv3d_transpose_1/bias:0
        Variable:0
        beta1_power:0
        beta2_power:0
        Wconv1/Adam:0
        Wconv1/Adam_1:0
        bconv1/Adam:0
        bconv1/Adam_1:0
        Wconv2/Adam:0
        Wconv2/Adam_1:0
        bconv2/Adam:0
        bconv2/Adam_1:0
        Wconv3/Adam:0
        Wconv3/Adam_1:0
        bconv3/Adam:0
        bconv3/Adam_1:0
        conv3d_transpose/kernel/Adam:0
        conv3d_transpose/kernel/Adam_1:0
        ...
        conv3d_transpose_1/bias/Adam:0
        conv3d_transpose_1/bias/Adam_1:0
        Local Variables
        Epoch 1, Overall loss = 9.01e+03 and accuracy of 0.214 **Poor validation performance**

Upvotes: 0

Views: 1016

Answers (1)

Noorul Hasan

Reputation: 21

There is a lot of code in your post and it takes a while to go through, so here is a simple example of how to restore your latest checkpoint:

    import os
    import tensorflow as tf

    def restore_and_predict(img_name):       # wrapper added so the snippet is runnable; name is illustrative
        with tf.Graph().as_default():
            x = tf.placeholder(tf.float32, shape=[None, 784])
            output = inference(x)             # custom model-building function
            saver = tf.train.Saver()
            sess = tf.Session()
            base_dir = os.path.dirname(os.path.abspath(__file__))
            checkpoint_path = os.path.join(os.path.dirname(base_dir), "logistic_regression/logistic_logs/")
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

            result = sess.run(tf.argmax(output, 1), feed_dict={x: conv_mnist(img_name)})  # conv_mnist: custom image loader
            print("Result:", result)
            return result

Note that saver.restore() loads the checkpointed values directly into the graph's variables, so there is no need to run global_variables_initializer() beforehand. I hope this code sample helps.

Upvotes: 1
