Reputation: 361
I would really appreciate it if I could get some help in saving and restoring LSTMs.
I have this LSTM layer -
# LSTM cell
cell = tf.contrib.rnn.LSTMCell(n_hidden)
output, current_state = tf.nn.dynamic_rnn(cell, word_vectors, dtype=tf.float32)
outputs = tf.transpose(output, [1, 0, 2])  # [max_time, batch_size, n_hidden]
last = tf.gather(outputs, int(outputs.get_shape()[0]) - 1)  # output at the last time step
# Saver function
saver = tf.train.Saver()
saver.save(sess, 'test-model')
The saver saves the model and allows me to save and restore the weights and biases of the LSTM. However, I need to restore this LSTM layer and feed it a new set of inputs.
To restore the entire model, I'm doing:
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('test-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
Is it possible for me to initialize an LSTM cell with the pre-trained weights and biases?
If not, how do I restore this LSTM layer?
Thank you very much!
Upvotes: 7
Views: 638
Reputation: 17191
You are already loading the model, and with it the model's weights. All you need to do is use get_tensor_by_name to fetch any tensor from the restored graph and run it for inference.
Example:
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('test-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    # Get the tensors by the names assigned to them when the graph was built
    graph = tf.get_default_graph()
    word_vec = graph.get_tensor_by_name('word_vec:0')
    output_tensor = graph.get_tensor_by_name('outputs:0')
    sess.run(output_tensor, feed_dict={word_vec: ...})
In the above example, word_vec and outputs are the names assigned to the tensors when the graph was created. Make sure you assign names, so that the tensors can be retrieved by name after restoring.
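For example, a minimal sketch of how those names could be attached while building the original graph (the hyperparameter values below are just placeholders, and tf.identity is one way to give the output tensor a name; 'word_vec' and 'outputs' match the names looked up above):

import tensorflow as tf

# Illustrative hyperparameters -- substitute whatever your model actually uses
n_hidden, max_seq_len, embedding_dim = 64, 50, 100

# Name the input placeholder so it can be fetched later as 'word_vec:0'
word_vectors = tf.placeholder(tf.float32, [None, max_seq_len, embedding_dim], name='word_vec')

cell = tf.contrib.rnn.LSTMCell(n_hidden)
output, current_state = tf.nn.dynamic_rnn(cell, word_vectors, dtype=tf.float32)

# tf.identity attaches the name 'outputs' to the transposed tensor,
# so it can be fetched as 'outputs:0' after import_meta_graph
outputs = tf.identity(tf.transpose(output, [1, 0, 2]), name='outputs')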
Upvotes: 1