Amruth Lakkavaram

Reputation: 1517

Restore a saved neural network in TensorFlow

Before marking my question as a duplicate, please understand that I have gone through a lot of questions, but none of the solutions there cleared my doubts or solved my problem. I have a trained neural network that I want to save and later evaluate against a test dataset.

I tried saving and restoring it, but I am not getting the expected results. Restoring doesn't seem to work; maybe I am using it incorrectly. The model just uses the values given by the global variable initializer.

This is the code I am using for saving the model.

sess.run(tf.initializers.global_variables())
#num_epochs = 7
for epoch in range(num_epochs):  
  start_time = time.time()
  train_accuracy = 0
  train_loss = 0
  val_loss = 0
  val_accuracy = 0

  for bid in range(int(train_data_size/batch_size)):
     X_train_batch = X_train[bid*batch_size:(bid+1)*batch_size]
     y_train_batch = y_train[bid*batch_size:(bid+1)*batch_size]
     sess.run(optimizer, feed_dict = {x:X_train_batch, y:y_train_batch,prob:0.50})  

     train_accuracy = train_accuracy + sess.run(model_accuracy, feed_dict={x : X_train_batch,y:y_train_batch,prob:0.50})
     train_loss = train_loss + sess.run(loss_value, feed_dict={x : X_train_batch,y:y_train_batch,prob:0.50})

  for bid in range(int(val_data_size/batch_size)):
     X_val_batch = X_val[bid*batch_size:(bid+1)*batch_size]
     y_val_batch = y_val[bid*batch_size:(bid+1)*batch_size]
     val_accuracy = val_accuracy + sess.run(model_accuracy,feed_dict = {x:X_val_batch, y:y_val_batch,prob:0.75})
     val_loss = val_loss + sess.run(loss_value, feed_dict = {x:X_val_batch, y:y_val_batch,prob:0.75})

  train_accuracy = train_accuracy/int(train_data_size/batch_size)
  val_accuracy = val_accuracy/int(val_data_size/batch_size)
  train_loss = train_loss/int(train_data_size/batch_size)
  val_loss = val_loss/int(val_data_size/batch_size)


  end_time = time.time()


  saver.save(sess,'./blood_model_x_v2',global_step = epoch)  

After saving the model, files like these are written to my working directory:

blood_model_x_v2-2.data-0000-of-0001
blood_model_x_v2-2.index
blood_model_x_v2-2.meta

There are similar files for v2-3 through v2-6, plus a 'checkpoint' file. I then tried restoring the model with this code snippet (after initializing), but I am getting results different from the expected ones. What am I doing wrong?

saver = tf.train.import_meta_graph('blood_model_x_v2-5.meta')
saver.restore(test_session,tf.train.latest_checkpoint('./'))

Upvotes: 1

Views: 403

Answers (1)

Amir

Reputation: 16607

According to the TensorFlow docs for restore:

Restores previously saved variables.

This method runs the ops added by the constructor for restoring variables. It requires a session in which the graph was launched. The variables to restore do not have to have been initialized, as restoring is itself a way to initialize variables.
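
The last sentence of that quote is the one that matters for the question: if restoring already initializes the variables, then running an initializer after the restore would reset them. Below is a minimal sketch of that effect, not the asker's model. It assumes a TensorFlow install that provides `tensorflow.compat.v1`; the names `demo_model` and `bias_value` are made up for illustration, and `bias.assign(5.0)` merely stands in for training.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF1-style graphs and sessions

# Hypothetical checkpoint prefix inside a throwaway directory.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "demo_model")

# Build a tiny graph, move the variable away from its initial value, save it.
save_graph = tf.Graph()
with save_graph.as_default():
    bias = tf.Variable(2.0, name="bias")
    tf.identity(bias, name="bias_value")  # named op so we can fetch it later
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(bias.assign(5.0))  # stands in for training
        saver.save(sess, ckpt_prefix)

# Restore into a fresh graph and session, without initializing first.
restore_graph = tf.Graph()
with restore_graph.as_default():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
        saver.restore(sess, ckpt_prefix)
        restored = sess.run("bias_value:0")  # the saved value

        # Running the initializer AFTER restore resets the variable.
        sess.run(tf.global_variables_initializer())
        after_init = sess.run("bias_value:0")  # back to the initial value

print("restored:", restored, "after initializer:", after_init)
```

In this sketch `restored` holds the saved value, while `after_init` is back at the initial value, which matches the symptom described in the question.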

Let's see a complete example. We save the model like this:

import tensorflow as tf

# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object which will save all the variables
saver = tf.train.Saver()

# Run the operation by feeding input
print(sess.run(w4, feed_dict))
# Prints 24.0, which is (w1 + w2) * b1 = (4 + 8) * 2

# Now, save the graph
saver.save(sess, './ckpnt/my_test_model', global_step=1000)

And then load the trained model with:

import tensorflow as tf

sess = tf.Session()
# First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('./ckpnt/my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./ckpnt'))

# Now, let's access the saved placeholders and
# create a feed_dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))
# This will print 60.0, which is calculated using the
# new values of w1 and w2 and the saved value of b1: (13 + 17) * 2.

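As you can see, we do not run a variable initializer in the restoring part. Running the initializer after restore would overwrite the restored values with their initial ones, so either drop the initialization or run it before restoring. There is also a better way to save and restore a model, Checkpoint, which lets you check whether the model was restored correctly.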
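
As a rough illustration of the Checkpoint approach, here is a minimal sketch. It assumes that "Checkpoint" refers to `tf.train.Checkpoint` and that TensorFlow 2.x with eager execution is available, so it is not a drop-in replacement for the session-based code above. The variable, the attribute name `bias`, and the file prefix `ckpt` are made up for illustration.

```python
import os
import tempfile

import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# A single variable stands in for the model's weights.
bias = tf.Variable(2.0)
checkpoint = tf.train.Checkpoint(bias=bias)

bias.assign(5.0)  # stands in for training
save_path = checkpoint.save(os.path.join(tempfile.mkdtemp(), "ckpt"))

# Clobber the value, then restore it from disk.
bias.assign(0.0)
status = checkpoint.restore(save_path)

# Raises if objects and checkpointed values could not be fully matched,
# which is the "was it restored correctly" check.
status.assert_consumed()

restored_bias = float(bias.numpy())
print("restored bias:", restored_bias)
```

If the restore were pointed at a checkpoint that does not match the objects being tracked, `assert_consumed()` would be expected to fail loudly instead of silently leaving initial values in place.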

Upvotes: 2
