Reputation: 205
I am trying to reuse the bidirectional LSTM weights for two very similar computations, but I am getting an error and have no idea what I am doing wrong. I have a class for the basic module:
import tensorflow as tf

# (imports were trimmed; in TF 1.x the cell classes live under tf.nn.rnn_cell)
rnn_cell = tf.nn.rnn_cell
DropoutWrapper = tf.nn.rnn_cell.DropoutWrapper

class BasicAttn(object):
    def __init__(self, keep_prob, value_vec_size):
        self.keep_prob = keep_prob
        self.rnn_cell_fw = rnn_cell.LSTMCell(value_vec_size/2, reuse=True)
        self.rnn_cell_fw = DropoutWrapper(self.rnn_cell_fw, input_keep_prob=self.keep_prob)
        self.rnn_cell_bw = rnn_cell.LSTMCell(value_vec_size/2, reuse=True)
        self.rnn_cell_bw = DropoutWrapper(self.rnn_cell_bw, input_keep_prob=self.keep_prob)

    def build_graph(self, values, values_mask, keys):
        blended_reps = compute_blended_reps()  # attention computation, trimmed
        with tf.variable_scope('BasicAttn_BRNN', reuse=True):
            (fw_out, bw_out), _ = tf.nn.bidirectional_dynamic_rnn(
                self.rnn_cell_fw, self.rnn_cell_bw, blended_reps,
                dtype=tf.float32, scope='BasicAttn_BRNN')
Then the module gets called while building the graph:
attn_layer_start = BasicAttn(...)
blended_reps_start = attn_layer_start.build_graph(...)
attn_layer_end = BasicAttn(...)
blended_reps_end = attn_layer_end.build_graph(...)
But I get an error saying that TensorFlow is unable to reuse the RNN weights:
ValueError: Variable QAModel/BasicAttn_BRNN/BasicAttn_BRNN/fw/lstm_cell/kernel does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope
There is a lot of code, so I have trimmed out the parts which I thought were unnecessary.
Upvotes: 1
Views: 77
Reputation: 12795
reuse=True means that the variables have already been created previously (with reuse=False), so each tf.get_variable call (in your case abstracted behind the LSTM interface) expects the variable to already exist. To get a mode in which variables are created if they do not exist yet, and reused otherwise, you need to set reuse=tf.AUTO_REUSE (as the error message suggests).
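Here is a minimal sketch of the three modes (the scope and variable names are illustrative, not taken from your code):

import tensorflow as tf

# reuse=None (the default): tf.get_variable creates the variable.
with tf.variable_scope('demo'):
    v1 = tf.get_variable('v', shape=[1])

# reuse=True: tf.get_variable only looks the variable up; it raises a
# ValueError if 'demo/v' had not been created above.
with tf.variable_scope('demo', reuse=True):
    v2 = tf.get_variable('v', shape=[1])

# reuse=tf.AUTO_REUSE: create on the first use, reuse afterwards.
with tf.variable_scope('demo2', reuse=tf.AUTO_REUSE):
    w1 = tf.get_variable('w', shape=[1])  # created here
with tf.variable_scope('demo2', reuse=tf.AUTO_REUSE):
    w2 = tf.get_variable('w', shape=[1])  # reused here

assert v1 is v2 and w1 is w2  # same underlying variables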
So replace all occurrences of reuse=True with reuse=tf.AUTO_REUSE.
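Applied to your code, a sketch of the change (only the reuse arguments differ; the rest of BasicAttn stays as posted):

# In __init__: the cells create their weights on the first build_graph
# call and silently reuse them on the second.
self.rnn_cell_fw = rnn_cell.LSTMCell(value_vec_size/2, reuse=tf.AUTO_REUSE)
self.rnn_cell_bw = rnn_cell.LSTMCell(value_vec_size/2, reuse=tf.AUTO_REUSE)

# In build_graph: the enclosing scope must also allow creation on the
# first call.
with tf.variable_scope('BasicAttn_BRNN', reuse=tf.AUTO_REUSE):
    (fw_out, bw_out), _ = tf.nn.bidirectional_dynamic_rnn(
        self.rnn_cell_fw, self.rnn_cell_bw, blended_reps,
        dtype=tf.float32, scope='BasicAttn_BRNN')

This way attn_layer_start creates QAModel/BasicAttn_BRNN/BasicAttn_BRNN/fw/lstm_cell/kernel and attn_layer_end reuses it, so the two attention layers share the same LSTM weights, which seems to be what you intend.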
Here's the documentation: https://www.tensorflow.org/api_docs/python/tf/variable_scope
Upvotes: 1