Reputation: 848
I would like to create a new MultiRNNCell while reusing the old weights.
Starting with TensorFlow 1.1, you have to explicitly create the individual cells when you build a MultiRNNCell, and to reuse weights you have to pass a reuse=True flag. My code currently looks like this:
import tensorflow as tf
from tensorflow.contrib import rnn

def create_lstm_multicell():
    lstm_cell = lambda: rnn.LSTMCell(nstates, reuse=tf.get_variable_scope().reuse)
    lstm_multi_cell = rnn.MultiRNNCell([lstm_cell() for _ in range(n_layers)])
    return lstm_multi_cell
When I create the first multicell, the function should work as expected, and every cell inside the multicell gets independent weights and biases:
with tf.variable_scope('lstm') as scope:
    lstm1 = create_lstm_multicell()
Now I would like to create another one:
with tf.variable_scope('lstm') as scope:
    scope.reuse_variables()
    lstm2 = create_lstm_multicell()
I would like the first cell of lstm2 to use the weights and biases of the first cell of lstm1, the second cell to reuse the weights and biases of the second cell, and so on. But I suspect that since I call rnn.LSTMCell with reuse=True, the weights and biases of the first cell will be reused all the time.
P.S. For architectural reasons I do not want to reuse lstm1 itself; I would like to create a new multicell lstm2 that has the same weights.
Upvotes: 0
Views: 1868
Reputation: 848
TL;DR
It seems that in the code from the question the weights and biases of the cells are reused properly. The multicells lstm1 and lstm2 will behave identically, and the cells inside each MultiRNNCell will have independent weights and biases. I.e., in pseudocode:

lstm1._cells[0].weights == lstm2._cells[0].weights
lstm1._cells[1].weights == lstm2._cells[1].weights
Longer version
This is not a definitive answer, but it is the result of the research I have done so far.
It looks like a hack, but we can override the get_variable method to see which variables are accessed, for example like this:
from tensorflow.python.ops import variable_scope as vs

def verbose(original_function):
    # make a new function that prints a message each time original_function is called
    def new_function(*args, **kwargs):
        print('get variable:', '/'.join((tf.get_variable_scope().name, args[0])))
        result = original_function(*args, **kwargs)
        return result
    return new_function

vs.get_variable = verbose(vs.get_variable)
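The wrapping trick above is just the standard Python decorator pattern. In isolation it looks like the sketch below; the add function is a made-up example, and functools.wraps (which preserves the wrapped function's name and docstring, something the snippet above skips for brevity) is optional but idiomatic:

```python
import functools

def verbose(original_function):
    # log each call, then delegate to the wrapped function
    @functools.wraps(original_function)
    def new_function(*args, **kwargs):
        print('calling %s with args %r' % (original_function.__name__, args))
        return original_function(*args, **kwargs)
    return new_function

@verbose
def add(a, b):
    return a + b

print(add(1, 2))  # prints "calling add with args (1, 2)" and then 3
```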
Now we can run the following modified code:
def create_lstm_multicell(name):
    def lstm_cell(i, s):
        print('creating cell %i in %s' % (i, s))
        return rnn.LSTMCell(nstates, reuse=tf.get_variable_scope().reuse)
    lstm_multi_cell = rnn.MultiRNNCell([lstm_cell(i, name) for i in range(n_layers)])
    return lstm_multi_cell
with tf.variable_scope('lstm') as scope:
    lstm1 = create_lstm_multicell('lstm1')
    layer1, _ = tf.nn.dynamic_rnn(lstm1, x, dtype=tf.float32)
    val_1 = tf.reduce_sum(layer1)

with tf.variable_scope('lstm') as scope:
    scope.reuse_variables()
    lstm2 = create_lstm_multicell('lstm2')
    layer2, _ = tf.nn.dynamic_rnn(lstm2, x, dtype=tf.float32)
    val_2 = tf.reduce_sum(layer2)
The output will look like this (I removed repetitive lines):
creating cell 0 in lstm1
creating cell 1 in lstm1
get variable: lstm/rnn/multi_rnn_cell/cell_0/lstm_cell/weights
get variable: lstm/rnn/multi_rnn_cell/cell_0/lstm_cell/biases
get variable: lstm/rnn/multi_rnn_cell/cell_1/lstm_cell/weights
get variable: lstm/rnn/multi_rnn_cell/cell_1/lstm_cell/biases
creating cell 0 in lstm2
creating cell 1 in lstm2
get variable: lstm/rnn/multi_rnn_cell/cell_0/lstm_cell/weights
get variable: lstm/rnn/multi_rnn_cell/cell_0/lstm_cell/biases
get variable: lstm/rnn/multi_rnn_cell/cell_1/lstm_cell/weights
get variable: lstm/rnn/multi_rnn_cell/cell_1/lstm_cell/biases
This output indicates that lstm1 and lstm2 use the same weights and biases, while the first and second cells inside each MultiRNNCell have separate weights and biases.
Additionally, val_1 and val_2, the outputs of lstm1 and lstm2, remain identical during optimization.
I think MultiRNNCell creates the namespaces cell_0, cell_1, etc. inside itself, and since both multicells are built under the same enclosing 'lstm' scope, the corresponding variables resolve to the same full names and the weights are therefore reused between lstm1 and lstm2.
Upvotes: 3