bge0

Reputation: 921

Tensorflow: How to get all variables from rnn_cell.BasicLSTM & rnn_cell.MultiRNNCell

I have a setup where I need to initialize an LSTM after the main initialization, which uses tf.initialize_all_variables(), i.e. I want to call tf.initialize_variables([var_list]).

Is there a way to collect all the internal trainable variables for both rnn_cell.BasicLSTMCell and rnn_cell.MultiRNNCell, so that I can initialize JUST these parameters?

The main reason I want this is that I do not want to re-initialize values that were trained earlier.

Upvotes: 17

Views: 11760

Answers (2)

Minjoon Seo

Reputation: 546

You can also use tf.get_collection():

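import tensorflow as tf
from tensorflow.python.ops import rnn_cell  # TF 0.x module path

# num_nodes, num_steps, batch_size and input_data are assumed to be defined elsewhere.
cell = rnn_cell.BasicLSTMCell(num_nodes)
state = cell.zero_state(batch_size, tf.float32)
output = [None] * num_steps
with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    if i > 0:
      tf.get_variable_scope().reuse_variables()  # share the weights across time steps
    output[i], state = cell(input_data[i], state)
  lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)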

(partly copied from Rafal's answer)

Note that the last line is equivalent to the list comprehension in Rafal's code.

Basically, TensorFlow stores a global collection of variables, which can be fetched with either tf.all_variables() or tf.get_collection(tf.GraphKeys.VARIABLES). If you pass a scope name to tf.get_collection(), it returns only the items in that collection (variables, in this case) whose names fall under the specified scope.
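A minimal pure-Python sketch of that scope filter, assuming it behaves as a name-prefix test. The TensorFlow docs describe the scope argument as a re.match filter, so a scope name without regex special characters filters by prefix, which is the same thing the list comprehension in Rafal's answer does. The variable names below are hypothetical. The sketch also shows one pitfall: a bare scope name like "LSTM" also matches a sibling scope named, say, "LSTM_1"; adding a trailing "/" restricts the match to the scope itself.

```python
# Hypothetical variable names, as they might appear in the global collection.
all_variable_names = [
    "embedding:0",
    "LSTM/BasicLSTMCell/Linear/Matrix:0",
    "LSTM/BasicLSTMCell/Linear/Bias:0",
    "LSTM_1/BasicLSTMCell/Linear/Matrix:0",
    "softmax_w:0",
]

def filter_by_scope(names, scope):
    """Keep only the names that begin with `scope` (a plain prefix test)."""
    return [n for n in names if n.startswith(scope)]

# A bare scope name also picks up the sibling scope "LSTM_1".
loose = filter_by_scope(all_variable_names, "LSTM")
# Appending "/" keeps only variables created inside "LSTM" itself.
strict = filter_by_scope(all_variable_names, "LSTM/")

print(loose)
print(strict)
```

With the real API the same idea would be passing scope=vs.name + "/" instead of scope=vs.name, if sibling scopes with a shared prefix exist in your graph.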

EDIT: You can also use tf.GraphKeys.TRAINABLE_VARIABLES to get only the trainable variables. But since a vanilla BasicLSTMCell does not create any non-trainable variables, the two are functionally equivalent here. For a complete list of the default graph collections, see the tf.GraphKeys documentation.
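To make the VARIABLES vs. TRAINABLE_VARIABLES distinction concrete, here is a small pure-Python model of the bookkeeping tf.Variable does by default: every variable goes into the global variables collection, and those created with trainable=True also go into the trainable collection. The collection keys, helper functions and variable names are hypothetical stand-ins, not TensorFlow's API.

```python
from collections import defaultdict

# Hypothetical stand-ins for tf.GraphKeys.VARIABLES and tf.GraphKeys.TRAINABLE_VARIABLES.
VARIABLES = "variables"
TRAINABLE_VARIABLES = "trainable_variables"

graph_collections = defaultdict(list)

def make_variable(name, trainable=True):
    """Always register in VARIABLES; also register in TRAINABLE_VARIABLES if trainable."""
    graph_collections[VARIABLES].append(name)
    if trainable:
        graph_collections[TRAINABLE_VARIABLES].append(name)
    return name

def get_collection(key, scope=None):
    """Return the items under `key`, optionally keeping only names that start with `scope`."""
    return [n for n in graph_collections[key] if scope is None or n.startswith(scope)]

# Illustrative names: the LSTM weights are trainable, a step counter is not.
make_variable("LSTM/BasicLSTMCell/Linear/Matrix:0")
make_variable("LSTM/BasicLSTMCell/Linear/Bias:0")
make_variable("global_step:0", trainable=False)

print(len(get_collection(VARIABLES)))            # 3
print(len(get_collection(TRAINABLE_VARIABLES)))  # 2
# Inside the LSTM scope nothing is non-trainable, so both collections agree.
print(get_collection(VARIABLES, "LSTM/") == get_collection(TRAINABLE_VARIABLES, "LSTM/"))  # True
```

The two collections only diverge once something non-trainable (such as a step counter or moving averages) is created inside the same scope.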

Upvotes: 11

Rafał Józefowicz

Reputation: 6235

The easiest way to solve your problem is to use a variable scope. The names of the variables created within a scope are prefixed with the scope's name. Here is a short snippet:

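import tensorflow as tf
from tensorflow.python.ops import rnn_cell  # TF 0.x module path

# num_nodes, num_steps, batch_size and input_data are assumed to be defined elsewhere.
cell = rnn_cell.BasicLSTMCell(num_nodes)
state = cell.zero_state(batch_size, tf.float32)
output = [None] * num_steps

with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    if i > 0:
      tf.get_variable_scope().reuse_variables()  # share the weights across time steps
    output[i], state = cell(input_data[i], state)

  # Retrieve just the LSTM variables.
  lstm_variables = [v for v in tf.all_variables()
                    if v.name.startswith(vs.name)]

# [..]
# Initialize the LSTM variables. tf.initialize_variables() only builds an op,
# so it has to be run in a session (sess is assumed to be an existing tf.Session).
sess.run(tf.initialize_variables(lstm_variables))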

It would work the same way with MultiRNNCell.

EDIT: changed tf.trainable_variables() to tf.all_variables()

Upvotes: 17
