user3921

Reputation: 241

Restoring a single variable tensor saved in one model to variable tensor in another model - Tensorflow

Running on TensorFlow 1.3.0 (GPU). I have trained a model in TF and saved just a single variable tensor using:

embeddings = tf.Variable(tf.random_uniform([4**kmer_len, embedding_size], -0.04, 0.04), name='Embeddings')
# more code, variables...
saver = tf.train.Saver({"Embeddings": embeddings})  # saving only the embeddings variable
# some more code, training the model...
saver.save(ses, './embeddings/embedding_mat')       # saving the variable

Now, I have a different model in a different file and I would like to restore just the single saved embeddings variable into it. The problem is that this new model has some more variables. When I try to restore the variable by doing:

embeddings = tf.Variable(tf.random_uniform([4**kmer_len_emb, embedding_size], -0.04, 0.04), name='Embeddings')
dense1 = tf.layers.dense(inputs=kmer_flattened, units=200, activation=tf.nn.relu, use_bias=True)  
ses = tf.Session()
init = tf.global_variables_initializer()
ses.run(init)
saver = tf.train.Saver()
saver.restore(ses, './embeddings/embedding_mat')

I'm getting a "not found in checkpoint" error. Any thoughts on how to deal with this? Thanks.

Upvotes: 0

Views: 111

Answers (2)

Giuseppe Marra

Reputation: 1104

You must create an instance of Saver over just that variable:

saver = tf.train.Saver(var_list=[embeddings])

This tells your Saver instance to take care of restoring/saving only that particular variable of the graph; otherwise it will try to restore/save all the variables of the graph.

Upvotes: 1

one

Reputation: 2585

It is because it can't find dense1's variables in the checkpoint. Try this:

all_var = tf.global_variables()
var_to_restore = [v for v in all_var if v.name == 'Embeddings:0']
ses.run(init)
saver = tf.train.Saver(var_to_restore)
saver.restore(ses, './embeddings/embedding_mat')

Upvotes: 1
