Iakov Davydov

Reputation: 848

How to reuse weights in MultiRNNCell?

I would like to create a new MultiRNNCell while reusing the old weights.

Starting from TensorFlow 1.1, when you create a MultiRNNCell you have to explicitly create the new cells. To reuse weights you have to pass the reuse=True flag. In my code I currently have:

import tensorflow as tf
from tensorflow.contrib import rnn

def create_lstm_multicell():
    # nstates and n_layers are defined elsewhere in my code
    lstm_cell = lambda: rnn.LSTMCell(nstates, reuse=tf.get_variable_scope().reuse)
    lstm_multi_cell = rnn.MultiRNNCell([lstm_cell() for _ in range(n_layers)])
    return lstm_multi_cell

When I create the first multicell, the function should work as expected, and every cell inside the multilayer element should have independent weights and biases:

with tf.variable_scope('lstm') as scope:
    lstm1 = create_lstm_multicell()

Now I would like to create another one:

with tf.variable_scope('lstm') as scope:
    scope.reuse_variables()
    lstm2 = create_lstm_multicell()

I would like the first cell of lstm2 to use the weights and biases of the first cell of lstm1, the second cell to reuse the weights and biases of the second cell, and so on. But I suspect that since I call rnn.LSTMCell with reuse=True, the weights and biases of the first cell will be reused all the time.

  1. How do I ensure that the weights are reused properly?
  2. If they are not, how do I force the desired behavior?

P.S. For architectural reasons I do not want to reuse lstm1 itself; I would like to create a new multicell lstm2 that shares the same weights.

Upvotes: 0

Views: 1868

Answers (1)

Iakov Davydov

Reputation: 848

TL;DR

It seems that in the code from the question the weights and biases of the cells are reused properly. The multicells lstm1 and lstm2 behave identically, and the cells inside each MultiRNNCell have independent weights and biases. I.e., in pseudocode:

lstm1._cells[0].weights == lstm2._cells[0].weights
lstm1._cells[1].weights == lstm2._cells[1].weights
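
A quick, hedged way to check this (assuming the graph has been built, e.g. by running each multicell through tf.nn.dynamic_rnn as in the longer version below) is to list the graph's variables: if the weights are really shared, each layer's variables appear only once, not once per multicell:

# Sketch: list all variables in the graph after building both multicells.
# If reuse works, there is exactly one weights/biases pair per layer,
# shared by lstm1 and lstm2.
for v in tf.global_variables():
    print(v.name)
# Expected output (one set per layer, not per multicell):
#   lstm/rnn/multi_rnn_cell/cell_0/lstm_cell/weights:0
#   lstm/rnn/multi_rnn_cell/cell_0/lstm_cell/biases:0
#   lstm/rnn/multi_rnn_cell/cell_1/lstm_cell/weights:0
#   lstm/rnn/multi_rnn_cell/cell_1/lstm_cell/biases:0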

Longer version

This is not a definitive answer, but it is the result of the research I have done so far.

It looks like a hack, but we can override the get_variable function to see which variables are accessed, for example like this:

from tensorflow.python.ops import variable_scope as vs

def verbose(original_function):
    # make a new function that prints which variable is requested
    # before delegating to original_function
    def new_function(*args, **kwargs):
        print('get variable:', '/'.join((tf.get_variable_scope().name, args[0])))
        result = original_function(*args, **kwargs)
        return result
    return new_function

vs.get_variable = verbose(vs.get_variable)
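
As far as I can tell, this works because the cell implementations call get_variable through the variable_scope module, so patching vs.get_variable intercepts their lookups (rebinding tf.get_variable alone would not).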

Now we can run the following modified code:

def create_lstm_multicell(name):
    # nstates and n_layers are defined as before
    def lstm_cell(i, s):
        print('creating cell %i in %s' % (i, s))
        return rnn.LSTMCell(nstates, reuse=tf.get_variable_scope().reuse)
    lstm_multi_cell = rnn.MultiRNNCell([lstm_cell(i, name) for i in range(n_layers)])
    return lstm_multi_cell

# x is the input tensor, defined elsewhere
with tf.variable_scope('lstm') as scope:
    lstm1 = create_lstm_multicell('lstm1')
    layer1, _ = tf.nn.dynamic_rnn(lstm1, x, dtype=tf.float32)
    val_1 = tf.reduce_sum(layer1)

with tf.variable_scope('lstm') as scope:
    scope.reuse_variables()
    lstm2 = create_lstm_multicell('lstm2')
    layer2, _ = tf.nn.dynamic_rnn(lstm2, x, dtype=tf.float32)
    val_2 = tf.reduce_sum(layer2)

The output will look like this (I removed repetitive lines):

creating cell 0 in lstm1
creating cell 1 in lstm1
get variable: lstm/rnn/multi_rnn_cell/cell_0/lstm_cell/weights
get variable: lstm/rnn/multi_rnn_cell/cell_0/lstm_cell/biases
get variable: lstm/rnn/multi_rnn_cell/cell_1/lstm_cell/weights
get variable: lstm/rnn/multi_rnn_cell/cell_1/lstm_cell/biases
creating cell 0 in lstm2
creating cell 1 in lstm2
get variable: lstm/rnn/multi_rnn_cell/cell_0/lstm_cell/weights
get variable: lstm/rnn/multi_rnn_cell/cell_0/lstm_cell/biases
get variable: lstm/rnn/multi_rnn_cell/cell_1/lstm_cell/weights
get variable: lstm/rnn/multi_rnn_cell/cell_1/lstm_cell/biases

This output indicates that the cells of lstm1 and lstm2 use the same weights and biases, while the first and second cells inside each MultiRNNCell have separate weights and biases.

Additionally, val_1 and val_2, the outputs of lstm1 and lstm2, remain identical during optimization.
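
You can also verify this numerically. A minimal sketch (the feed array batch is a hypothetical NumPy array with a shape matching x):

import numpy as np

# Sketch: with shared weights, both multicells produce the same output
# for the same input. `batch` is a hypothetical NumPy array matching x.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    v1, v2 = sess.run([val_1, val_2], feed_dict={x: batch})
    assert np.isclose(v1, v2)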

I think MultiRNNCell creates the namespaces cell_0, cell_1, etc. inside itself, and since both multicells are built inside the same 'lstm' scope with reuse enabled, the weights are therefore shared between lstm1 and lstm2 while staying independent between layers.
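
The same mechanism can be reproduced with plain tf.get_variable. This is only an illustrative sketch (the scope and variable names here are made up, not MultiRNNCell's actual code):

def make_layers():
    variables = []
    for i in range(2):
        # MultiRNNCell does something analogous: one sub-scope per layer
        with tf.variable_scope('cell_%d' % i):
            variables.append(tf.get_variable('weights', shape=[3, 3]))
    return variables

with tf.variable_scope('demo'):
    first = make_layers()   # creates demo/cell_0/weights, demo/cell_1/weights

with tf.variable_scope('demo') as scope:
    scope.reuse_variables()
    second = make_layers()  # returns the already existing variables

print(first[0] is second[0])  # True: shared between the two passes
print(first[0] is first[1])   # False: independent between layers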

Upvotes: 3
