wei he

Reputation: 13

How to save a TensorFlow dynamic_rnn model and restore it as the decoder in a new encoder-decoder model?

I am trying to train an encoder-decoder model to generate summaries automatically. The encoder uses a CNN to encode an article's abstract, and the decoder is an RNN that generates the article's title.

So the skeleton looks like:

encoder_state = CNNEncoder(encoder_inputs)
decoder_outputs, _ = RNNDecoder(encoder_state,decoder_inputs)

But I want to pre-train the RNN decoder first, so that the model learns how to "speak" before it is combined with the encoder. The decoder part is:

def RNNDecoder(encoder_state,decoder_inputs):
    decoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, decoder_inputs)
    # GRU cell from tf.contrib.rnn (TensorFlow 1.x)
    cell = tf.contrib.rnn.GRUCell(memory_dim)
    decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
        cell, decoder_inputs_embedded,
        initial_state=encoder_state,
        dtype=tf.float32,scope="plain_decoder1"
    )
    return decoder_outputs, decoder_final_state

So my question is: how can I save and restore the RNNDecoder part separately?

Upvotes: 0

Views: 822

Answers (1)

Shamane Siriwardhana

Reputation: 4201

First, take the output of the dynamic RNN:

decoder_cell = tf.contrib.rnn.LSTMCell(decoder_hidden_units)
decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
    decoder_cell, decoder_inputs_embedded,
    initial_state=encoder_final_state,
    dtype=tf.float32, time_major=True, scope="plain_decoder")

Take decoder_outputs and pass them through a fully connected (linear) layer to get one logit per vocabulary entry; the softmax is applied to these logits.

decoder_logits = tf.contrib.layers.linear(decoder_outputs, vocab_size)

Then you can create a softmax cross-entropy loss from decoder_logits and train it in the normal way.
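
To make the loss concrete, here is a minimal, library-free sketch of what a softmax cross-entropy over the decoder logits computes. The function names and the toy numbers are made up for illustration; in a real TensorFlow graph you would use the library's own loss op on decoder_logits instead of this.

```python
import math

def softmax_cross_entropy(logits, target_id):
    """Cross-entropy for one time step: -log(softmax(logits)[target_id])."""
    # Subtract the max logit for numerical stability; the result is unchanged.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target_id]

def sequence_loss(step_logits, target_ids):
    """Mean cross-entropy over the time steps of one title."""
    losses = [softmax_cross_entropy(l, t)
              for l, t in zip(step_logits, target_ids)]
    return sum(losses) / len(losses)

# Toy example (hypothetical values): 2 time steps, vocabulary of 3 tokens.
step_logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
target_ids = [0, 2]
loss = sequence_loss(step_logits, target_ids)
```

With uniform logits the per-step loss equals log(vocab_size), which is a handy sanity check for an untrained decoder.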

When you want to restore the parameters, do something like this inside a session:

with tf.Session() as session:
    saver = tf.train.Saver()
    saver.restore(session, checkpoint_file)

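Here checkpoint_file must be the exact path of the checkpoint you saved after pre-training. Note that a saver created without arguments handles every variable in the graph, so to restore only the decoder weights the saver has to be limited to the decoder variables (for example, those created under the "plain_decoder" scope, which must have the same name in both graphs). The remaining variables are initialized as usual and trained together with the restored decoder in the main model.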
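
The idea of restoring only the decoder can be sketched without TensorFlow as a filter on variable names. This is only an illustration under assumptions: the dictionaries stand in for a checkpoint and a graph, and the variable names are hypothetical. In TensorFlow 1.x the corresponding step would be, as far as I know, to pass the decoder variables (e.g. collected with tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="plain_decoder")) as var_list to tf.train.Saver.

```python
def restore_scope(checkpoint, model_vars, scope):
    """Copy checkpoint values into model_vars for names under `scope` only."""
    prefix = scope.rstrip("/") + "/"
    restored = []
    for name, value in checkpoint.items():
        # Skip anything outside the scope or unknown to the new model.
        if name.startswith(prefix) and name in model_vars:
            model_vars[name] = value
            restored.append(name)
    return sorted(restored)

# Hypothetical variable names; the real ones depend on how the graph is built.
pretrained = {
    "plain_decoder/gru_cell/kernel": [0.1, 0.2],
    "plain_decoder/gru_cell/bias": [0.0],
    "optimizer/step": 1000,
}
full_model = {
    "cnn_encoder/conv1/kernel": [9.9],
    "plain_decoder/gru_cell/kernel": [0.0, 0.0],
    "plain_decoder/gru_cell/bias": [1.0],
}
restored = restore_scope(pretrained, full_model, "plain_decoder")
```

After the call, the decoder entries hold the pre-trained values while the encoder entry keeps its fresh initialization, which is the behaviour you want before joint training.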

Upvotes: 1
