AnnaR

Reputation: 361

How to restore an LSTM layer

I would really appreciate some help with saving and restoring LSTMs.

I have this LSTM layer:

# LSTM cell
cell = tf.contrib.rnn.LSTMCell(n_hidden)
output, current_state = tf.nn.dynamic_rnn(cell, word_vectors, dtype=tf.float32)

# Take the output at the last time step
outputs = tf.transpose(output, [1, 0, 2])
last = tf.gather(outputs, int(outputs.get_shape()[0]) - 1)

# Save the model
saver = tf.train.Saver()
saver.save(sess, 'test-model')

The saver saves the model and allows me to save and restore the weights and biases of the LSTM. However, I need to restore this LSTM layer and feed it a new set of inputs.

To restore the entire model, I'm doing:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('test-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

  1. Is it possible for me to initialize an LSTM cell with the pre-trained weights and biases?

  2. If not, how do I restore this LSTM layer?

Thank you very much!

Upvotes: 7

Views: 638

Answers (1)

Vijay Mariappan

Reputation: 17191

You are already loading the model, and with it the model's weights. All you need to do is use get_tensor_by_name to retrieve any tensor from the restored graph and use it for inference.

Example:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('test-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    # Get the tensors by the names assigned to them in the graph
    graph = tf.get_default_graph()
    word_vec = graph.get_tensor_by_name('word_vec:0')
    output_tensor = graph.get_tensor_by_name('outputs:0')

    # Run inference on the new inputs
    sess.run(output_tensor, feed_dict={word_vec: ...})

In the above example, word_vec and outputs are the names assigned to the tensors when the graph was created. Make sure you assign names to the tensors you will need later, so that they can be retrieved by name after the graph is restored.
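For completeness, here is a minimal sketch of how those names could be assigned when the graph is first built. The placeholder shape and the seq_len, embed_dim and n_hidden values are illustrative assumptions, not taken from the question; only the name='word_vec' and name='outputs' arguments matter here.

import tensorflow as tf

# Illustrative sizes only; the real values come from your own data/model
seq_len, embed_dim, n_hidden = 50, 300, 128

# Naming the input placeholder makes it retrievable later as 'word_vec:0'
word_vectors = tf.placeholder(tf.float32, [None, seq_len, embed_dim], name='word_vec')

cell = tf.contrib.rnn.LSTMCell(n_hidden)
output, current_state = tf.nn.dynamic_rnn(cell, word_vectors, dtype=tf.float32)

# Last time step's output, wrapped in tf.identity so it can be fetched as 'outputs:0'
outputs = tf.transpose(output, [1, 0, 2])
last = tf.gather(outputs, int(outputs.get_shape()[0]) - 1)
last = tf.identity(last, name='outputs')

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'test-model')

With the names in place, the restore-and-run snippet above can fetch 'word_vec:0' and 'outputs:0' by name without rebuilding the LSTM layer by hand.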

Upvotes: 1
