Reputation: 2104
So, I trained a TensorFlow model with a few layers, more or less like this:
with tf.variable_scope('model1') as scope:
    inputs = tf.placeholder(tf.int32, [None, num_time_steps])
    embeddings = tf.get_variable('embeddings', (vocab_size, embedding_size))
    lstm = tf.nn.rnn_cell.LSTMCell(lstm_units)
    embedded = tf.nn.embedding_lookup(embeddings, inputs)
    _, state = tf.nn.dynamic_rnn(lstm, embedded, dtype=tf.float32, scope=scope)
    # more stuff on the state
Now I want to reuse the embedding matrix and the LSTM weights in another model, which is very different from this one except for these two components.
As far as I know, if I load them with a tf.train.Saver object, it will look for variables with the exact same names, but I'm using different variable_scopes in the two graphs.
In this answer, it is suggested to create the graph where the LSTM is trained as a superset of the other one, but I don't think it is possible in my case, given the differences in the two models. Anyway, I don't think it is a good idea to make one graph dependent on the other, if they do independent things.
I thought about changing the variable scope of the LSTM weights and embeddings in the serialized graph. I mean, where it originally read model1/Weights:0 or something, it would be another_scope/Weights:0. Is it possible and feasible?
Of course, if there is a better solution, it is also welcome.
Upvotes: 2
Views: 872
I found out that the Saver can be initialized with a dictionary mapping variable names (without the trailing :0) in the serialized file to the variable objects I want to restore in the graph. For example:
varmap = {'model1/some_scope/weights': variable_in_model2,
          'model1/another_scope/weights': another_variable_in_model2}
saver = tf.train.Saver(varmap)
saver.restore(sess, path_to_saved_file)
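When there are many shared variables, writing the dictionary by hand gets tedious. Here is a minimal sketch of building it from a scope-prefix rewrite instead. The helper name build_restore_map and the scope name another_scope are hypothetical, and it assumes that everything after the scope prefix is named identically in both graphs, which is worth checking against the names actually stored in the checkpoint.

```python
def build_restore_map(variables, old_scope, new_scope):
    """Map checkpoint names under old_scope to variables living under new_scope.

    `variables` is any iterable of objects with a `.name` attribute such as
    'another_scope/embeddings:0' (hypothetical helper, not a TensorFlow API).
    """
    new_prefix = new_scope.rstrip('/') + '/'
    old_prefix = old_scope.rstrip('/') + '/'
    restore_map = {}
    for var in variables:
        # Drop the trailing output index (":0"); checkpoint keys do not include it.
        name = var.name.split(':')[0]
        if not name.startswith(new_prefix):
            continue  # variable lives outside the scope being restored
        # Swap the new scope prefix for the one used when the checkpoint was saved.
        restore_map[old_prefix + name[len(new_prefix):]] = var
    return restore_map


# Assumed usage with TensorFlow 1.x, where the second model was built under
# tf.variable_scope('another_scope'):
#
#   shared = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='another_scope')
#   saver = tf.train.Saver(build_restore_map(shared, 'model1', 'another_scope'))
#   saver.restore(sess, path_to_saved_file)
```

Only the variables that appear in the dictionary are restored, so anything else in the second model still has to be initialized separately.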
Upvotes: 1