Lukeyb

Reputation: 857

Making a trainable initial state for an LSTM in TensorFlow

I have a sequence that is too long to fit in memory, but the initial state is quite critical, so I would like to train it as a variable too. How can I train an initial state variable that is passed in at the start of the sequence, while still feeding the output state back in for the rest of the sequence?

This is what I've got so far:

    cell = tf.contrib.rnn.BasicLSTMCell(num_lstm_cells, state_is_tuple=True)

    # Trainable initial state
    init_vars = cell.zero_state(batch_size, tf.float32)
    init_c = tf.Variable(init_vars.c, trainable=True)
    init_h = tf.Variable(init_vars.h, trainable=True)
    init_state = tf.contrib.rnn.LSTMStateTuple(init_c, init_h)

    # Non-trainable running state, carried across chunks of the sequence
    state_vars = cell.zero_state(batch_size, tf.float32)
    state_c = tf.Variable(state_vars.c, trainable=False)
    state_h = tf.Variable(state_vars.h, trainable=False)
    state = tf.contrib.rnn.LSTMStateTuple(state_c, state_h)

    layer = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=0.7)
    val, new_state = tf.nn.dynamic_rnn(layer, lstm_input, initial_state=state, dtype=tf.float32)

    # Write the final state back into the running state whenever the output is used
    with tf.control_dependencies([state[0].assign(new_state[0]), state[1].assign(new_state[1])]):
        output = tf.identity(val)

    # Op to reset the running state to the (trainable) initial state
    initialise_c = tf.assign(state[0], init_state[0])
    initialise_h = tf.assign(state[1], init_state[1])
    initialise_state = tf.group(initialise_c, initialise_h)

The idea is that I have a trainable initial state (init_state) and a non-trainable running state (state), and I copy the initial state into the running state at the start of each sequence by calling the initialise_state op.

I don't think this will work, though, because init_state isn't actually part of the training graph; it is only used for copying. How can I do this?

Edit: I've confirmed in testing that the initial state is not being trained and remains all zeros.
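For reference, a quick way to check this (assuming a running session; sess.run returns NumPy arrays):

    c_val, h_val = sess.run([init_c, init_h])
    print((c_val == 0).all(), (h_val == 0).all())  # both print True if the state was never trained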

Upvotes: 3

Views: 2780

Answers (3)

Lukeyb

Reputation: 857

I ended up solving this by creating the trainable state variables inside a separate variable scope. Using the optional var_list parameter of Optimizer.minimize(), I could build one training op that also updates the state variables (used at the start of each sequence) and one that leaves them alone (used for the rest of it). After training the initial state on the first chunk of a sequence, I save it into a pair of non-trainable backup variables and restore it from there at the start of the next sequence.

    # Trainable state variables live in their own scope so they can be
    # collected separately from the network weights
    with tf.variable_scope("state"):
        state_c = tf.Variable(tf.random_uniform([batch_size, num_lstm_cells], 0, 1), trainable=True)
        state_h = tf.Variable(tf.random_uniform([batch_size, num_lstm_cells], 0, 1), trainable=True)
        state = tf.contrib.rnn.LSTMStateTuple(state_c, state_h)

    with tf.variable_scope("nn"):
        # `cell` is the BasicLSTMCell from the question; `lstm_input` and
        # `targets` are assumed to be defined elsewhere (e.g. placeholders)
        layer = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=0.7)
        val, new_state = tf.nn.dynamic_rnn(layer, lstm_input, initial_state=state, dtype=tf.float32)

        logits = tf.layers.dense(val, units=5, activation=tf.nn.relu)
        losses = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=targets)

    # Non-trainable backup of the trained initial state
    init_c = tf.Variable(tf.zeros([batch_size, num_lstm_cells]), trainable=False)
    init_h = tf.Variable(tf.zeros([batch_size, num_lstm_cells]), trainable=False)
    init_state = tf.contrib.rnn.LSTMStateTuple(init_c, init_h)

    # Restore the saved initial state at the start of a sequence
    restore_c = tf.assign(state[0], init_state[0])
    restore_h = tf.assign(state[1], init_state[1])
    restore_state = tf.group(restore_c, restore_h)

    # Save the freshly trained initial state into the backup
    save_c = tf.assign(init_state[0], state[0])
    save_h = tf.assign(init_state[1], state[1])
    save_state = tf.group(save_c, save_h)

    # Carry the RNN's final state into the next chunk of the sequence
    propagate_c = tf.assign(state[0], new_state[0])
    propagate_h = tf.assign(state[1], new_state[1])
    propagate_state = tf.group(propagate_c, propagate_h)

    nn_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "nn")
    state_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "state")

    total_loss = tf.reduce_mean(losses)

    # One op trains only the network; the other trains the network and the state
    train_nn_step = tf.train.AdamOptimizer().minimize(total_loss, var_list=nn_vars)
    train_nn_state_step = tf.train.AdamOptimizer().minimize(total_loss, var_list=nn_vars + state_vars)

So you start a sequence by calling:

  1. sess.run(restore_state) to copy the saved initial state back into the graph
  2. _, loss = sess.run([train_nn_state_step, total_loss]) to train the initial state and the network
  3. sess.run(save_state) to save the initial state
  4. sess.run(propagate_state) to propagate the state to the next training step

And you train the rest of the sequence (see the loop sketched after this list) by calling:

  1. _, loss = sess.run([train_nn_step, total_loss]) to train just the network
  2. sess.run(propagate_state) to keep passing the state through
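Putting the two phases together, the outer loop might look something like this. This is only a sketch: num_sequences and steps_per_sequence are illustrative names, and sess is assumed to be a tf.Session after tf.global_variables_initializer() has run:

    for _ in range(num_sequences):
        sess.run(restore_state)                                # load the trained initial state
        _, loss = sess.run([train_nn_state_step, total_loss])  # train network + initial state
        sess.run(save_state)                                   # back up the updated initial state
        sess.run(propagate_state)                              # carry the state into the next chunk

        for _ in range(steps_per_sequence - 1):
            _, loss = sess.run([train_nn_step, total_loss])    # train the network only
            sess.run(propagate_state)                          # keep threading the state through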

Upvotes: 2

amin__

Reputation: 1058

I am not sure exactly what you want to do, but why not assign new_state to another state variable, like below:

    batch_size = 10
    num_lstm_cells = 20
    num_times = 5
    input_dims = 6

    # Dummy input so the snippet is self-contained
    lstm_input = tf.random_normal([batch_size, num_times, input_dims], 0., 1.0)

    cell = tf.contrib.rnn.BasicLSTMCell(num_lstm_cells, state_is_tuple=True)

    # Trainable initial state
    init_vars = cell.zero_state(batch_size, tf.float32)
    init_c = tf.Variable(init_vars.c, trainable=True)
    init_h = tf.Variable(init_vars.h, trainable=True)
    init_state = tf.contrib.rnn.LSTMStateTuple(init_c, init_h)

    # Non-trainable running state
    state_vars = cell.zero_state(batch_size, tf.float32)
    state_c = tf.Variable(state_vars.c, trainable=False)
    state_h = tf.Variable(state_vars.h, trainable=False)
    state = tf.contrib.rnn.LSTMStateTuple(state_c, state_h)

    layer = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=0.7)
    val, new_state = tf.nn.dynamic_rnn(layer, lstm_input, initial_state=state, dtype=tf.float32)

    # Assign ops that copy the RNN's final state into the running state
    trained_state_c = tf.assign(state[0], new_state[0])
    trained_state_h = tf.assign(state[1], new_state[1])
    trained_state = tf.contrib.rnn.LSTMStateTuple(trained_state_c, trained_state_h)
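Note that the assign ops only run when they are fetched. A minimal sketch of how this might be driven (assuming a fresh session):

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Fetching trained_state forces the tf.assign ops to run, so the
        # running state is updated alongside the output
        out, _ = sess.run([val, trained_state])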

Upvotes: 0

ikamen

Reputation: 3493

How about alternating between training the network and training the initial state? Freeze the model weights, make the initial state trainable, and train for some time; then swap which set is frozen. A sketch of the alternation is below.
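A minimal sketch of that alternation, reusing the var_list trick from the accepted answer; the scope names "nn" and "state", the loss op total_loss, and the loop bounds are all assumptions:

    # Hypothetical sketch: two optimizers, each updating a different variable set
    nn_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "nn")
    state_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "state")

    train_nn = tf.train.AdamOptimizer().minimize(total_loss, var_list=nn_vars)
    train_state = tf.train.AdamOptimizer().minimize(total_loss, var_list=state_vars)

    for phase in range(num_phases):               # num_phases is illustrative
        step_op = train_state if phase % 2 == 0 else train_nn
        for _ in range(steps_per_phase):          # steps_per_phase is illustrative
            sess.run(step_op)                     # only the unfrozen set is updated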

Upvotes: 0
