erickrf

Reputation: 2104

Reusing part of a trained TensorFlow graph

So, I trained a TensorFlow model with a few layers, more or less like this:

import tensorflow as tf

# num_time_steps, vocab_size, embedding_size and lstm_units are
# hyperparameters defined elsewhere
with tf.variable_scope('model1') as scope:

    inputs = tf.placeholder(tf.int32, [None, num_time_steps])
    embeddings = tf.get_variable('embeddings', (vocab_size, embedding_size))
    lstm = tf.nn.rnn_cell.LSTMCell(lstm_units)

    embedded = tf.nn.embedding_lookup(embeddings, inputs)
    _, state = tf.nn.dynamic_rnn(lstm, embedded, dtype=tf.float32, scope=scope)

    # more stuff on the state

Now, I want to reuse the embedding matrix and the LSTM weights in another model, which is very different from this one except for these two components.

As far as I know, if I load them with a tf.train.Saver object, it will look for variables with the exact same names, but I'm using different variable scopes in the two graphs.

In this answer, it is suggested that the graph where the LSTM is trained be built as a superset of the other one, but I don't think that is possible in my case, given the differences between the two models. Anyway, I don't think it is a good idea to make one graph depend on the other if they do independent things.

I thought about changing the variable scope of the LSTM weights and embeddings in the serialized graph. I mean, where it originally reads model1/Weights:0 or something like that, it would read another_scope/Weights:0. Is that possible and practical?
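Roughly, I imagine something like the following (just a sketch with hypothetical checkpoint paths, reading the old file with tf.train.NewCheckpointReader and re-saving every tensor under a new name; I haven't verified it):

import tensorflow as tf

OLD_CHECKPOINT = 'model1.ckpt'   # hypothetical path to the saved model1
NEW_CHECKPOINT = 'renamed.ckpt'  # hypothetical path for the renamed copy

reader = tf.train.NewCheckpointReader(OLD_CHECKPOINT)

with tf.Graph().as_default():
    renamed = []
    for old_name in reader.get_variable_to_shape_map():
        value = reader.get_tensor(old_name)  # numpy array with the saved weights
        # swap the leading scope: model1/... -> another_scope/...
        new_name = old_name.replace('model1', 'another_scope', 1)
        renamed.append(tf.Variable(value, name=new_name))

    saver = tf.train.Saver(renamed)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, NEW_CHECKPOINT)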

Of course, if there is a better solution, it is also welcome.

Upvotes: 2

Views: 872

Answers (1)

erickrf

Reputation: 2104

I found out that the Saver can be initialized with a dictionary mapping variable names in the serialized file (without the :0 suffix) to the variable objects I want to restore in the graph. For example:

varmap = {'model1/some_scope/weights': variable_in_model2,
          'model1/another_scope/weights': another_variable_in_model2}

# keys are the names the variables had when the checkpoint was written;
# values are the corresponding variables in the new graph
saver = tf.train.Saver(varmap)
saver.restore(sess, path_to_saved_file)
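For instance, to restore only the embedding matrix from the question's model1 into a new scope, something along these lines should work (a minimal sketch, assuming matching shapes and a hypothetical checkpoint path):

import tensorflow as tf

vocab_size, embedding_size = 10000, 128  # must match the saved shapes

with tf.variable_scope('model2'):
    embeddings = tf.get_variable('embeddings', (vocab_size, embedding_size))

# 'model1/embeddings' is the name the variable had when it was saved
varmap = {'model1/embeddings': embeddings}
saver = tf.train.Saver(varmap)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize everything else
    saver.restore(sess, 'model1.ckpt')           # hypothetical checkpoint path

The exact names stored in the checkpoint (including those the LSTM cell created under model1/) can be listed with tf.train.NewCheckpointReader('model1.ckpt').get_variable_to_shape_map().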

Upvotes: 1
