I'm running TensorFlow 1.3.0 on GPU. I have trained a model in TF and saved just a single variable tensor using:
embeddings = tf.Variable(tf.random_uniform([4**kmer_len, embedding_size], -0.04, 0.04), name='Embeddings')
# ... more code, variables ...
saver = tf.train.Saver({"Embeddings": embeddings})  # save only the embeddings variable
# ... more code, training the model ...
saver.save(ses, './embeddings/embedding_mat')  # write the checkpoint
Now I have a different model in a different file, and I would like to restore just the single saved embeddings variable into it. The problem is that this new model has some additional variables. When I try to restore the variable by doing:
embeddings = tf.Variable(tf.random_uniform([4**kmer_len_emb, embedding_size], -0.04, 0.04), name='Embeddings')
dense1 = tf.layers.dense(inputs=kmer_flattened, units=200, activation=tf.nn.relu, use_bias=True)
ses = tf.Session()
init = tf.global_variables_initializer()
ses.run(init)
saver = tf.train.Saver()
saver.restore(ses, './embeddings/embedding_mat')
I'm getting a "not found in checkpoint" error. Any thoughts on how to deal with this? Thanks.
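One way to see where the error comes from is to compare the names stored in the checkpoint with the names of the variables in the new graph. The sketch below is a minimal reproduction under a few assumptions: it uses `tensorflow.compat.v1` (TF 1.14+ or 2.x; on TF 1.3 you would write `import tensorflow as tf` and drop `disable_v2_behavior()`), a temporary checkpoint path, small hypothetical shapes, and a `dense/kernel` variable created with `tf.variable_scope` as a stand-in for the weights that `tf.layers.dense` would create.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x (on TF 1.3: import tensorflow as tf)

tf.disable_v2_behavior()

ckpt = os.path.join(tempfile.mkdtemp(), "embedding_mat")  # hypothetical temp checkpoint prefix

# First model: save only the Embeddings variable, as in the question.
with tf.Graph().as_default():
    embeddings = tf.Variable(tf.zeros([64, 8]), name="Embeddings")
    saver = tf.train.Saver({"Embeddings": embeddings})
    with tf.Session() as ses:
        ses.run(tf.global_variables_initializer())
        saver.save(ses, ckpt)

# Second model: Embeddings plus a hypothetical stand-in for the dense layer's kernel.
with tf.Graph().as_default():
    tf.Variable(tf.zeros([64, 8]), name="Embeddings")
    with tf.variable_scope("dense"):
        tf.get_variable("kernel", shape=[8, 200])
    graph_names = {v.op.name for v in tf.global_variables()}

# Names stored in the checkpoint vs. names a plain tf.train.Saver() would look for.
ckpt_names = {name for name, _ in tf.train.list_variables(ckpt)}
missing = sorted(graph_names - ckpt_names)
print(missing)  # ['dense/kernel']
```

Every name in `missing` is a variable that a `tf.train.Saver()` created without arguments would try to restore and fail to find.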
You must create an instance of Saver for just that variable:
saver = tf.train.Saver(var_list=[embeddings])
This tells your Saver instance to save/restore only that particular variable of the graph; otherwise it will try to save/restore all the variables of the graph, including the dense layer's kernel and bias, which are not in the checkpoint.
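A self-contained sketch of the `var_list` approach, end to end. It assumes `tensorflow.compat.v1` (TF 1.14+ or 2.x; on TF 1.3 use `import tensorflow as tf` without `disable_v2_behavior()`), a temporary checkpoint path, hypothetical sizes for `kmer_len` and `embedding_size`, and a `dense/kernel` variable standing in for the weights of `tf.layers.dense`.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x (on TF 1.3: import tensorflow as tf)

tf.disable_v2_behavior()

kmer_len, embedding_size = 3, 8  # hypothetical sizes
ckpt = os.path.join(tempfile.mkdtemp(), "embedding_mat")  # hypothetical temp checkpoint prefix

# Model 1: save only the embeddings.
with tf.Graph().as_default():
    embeddings = tf.Variable(
        tf.random_uniform([4 ** kmer_len, embedding_size], -0.04, 0.04),
        name="Embeddings")
    saver = tf.train.Saver({"Embeddings": embeddings})
    with tf.Session() as ses:
        ses.run(tf.global_variables_initializer())
        saved = ses.run(embeddings)
        saver.save(ses, ckpt)

# Model 2: same embeddings plus an extra variable that is NOT in the checkpoint.
with tf.Graph().as_default():
    embeddings = tf.Variable(
        tf.random_uniform([4 ** kmer_len, embedding_size], -0.04, 0.04),
        name="Embeddings")
    with tf.variable_scope("dense"):  # hypothetical stand-in for tf.layers.dense weights
        kernel = tf.get_variable("kernel", shape=[embedding_size, 200])
    init = tf.global_variables_initializer()
    saver = tf.train.Saver(var_list=[embeddings])  # only Embeddings is saved/restored
    with tf.Session() as ses:
        ses.run(init)             # initialize everything, including dense/kernel
        saver.restore(ses, ckpt)  # then overwrite Embeddings from the checkpoint
        restored = ses.run(embeddings)
        kernel_value = ses.run(kernel)

print(np.allclose(saved, restored))  # True
print(kernel_value.shape)            # (8, 200)
```

The order matters: run the initializer first and restore afterwards. Restoring first and then running `global_variables_initializer()` would overwrite the restored embeddings with fresh random values.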
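It is because the checkpoint doesn't contain the variables created by dense1 (its kernel and bias), and a Saver() created without arguments tries to restore every variable in the graph. Try restoring only the embeddings: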
all_var = tf.global_variables()
var_to_restore = [v for v in all_var if v.name == 'Embeddings:0']
ses.run(init)
saver = tf.train.Saver(var_to_restore)
saver.restore(ses, './embeddings/embedding_mat')
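The name filter above can also be derived from the checkpoint itself instead of hard-coding `'Embeddings:0'`. A minimal sketch, assuming `tensorflow.compat.v1` (TF 1.14+ or 2.x; on TF 1.3 use `import tensorflow as tf` without `disable_v2_behavior()`), a temporary checkpoint path, and hypothetical `dense/kernel` and `dense/bias` variables standing in for the dense layer. It keeps every graph variable whose name and shape both match an entry reported by `tf.train.list_variables`. Note that checkpoint keys are op names such as `Embeddings`, so the comparison uses `v.op.name` rather than `v.name` (which carries the `:0` output suffix).

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14 or 2.x (on TF 1.3: import tensorflow as tf)

tf.disable_v2_behavior()

ckpt = os.path.join(tempfile.mkdtemp(), "embedding_mat")  # hypothetical temp checkpoint prefix

# First model: save only the embeddings.
with tf.Graph().as_default():
    emb = tf.Variable(tf.random_uniform([64, 8], -0.04, 0.04), name="Embeddings")
    saver = tf.train.Saver({"Embeddings": emb})
    with tf.Session() as ses:
        ses.run(tf.global_variables_initializer())
        saved = ses.run(emb)
        saver.save(ses, ckpt)

# Second model: embeddings plus hypothetical stand-ins for the dense layer's variables.
with tf.Graph().as_default():
    emb = tf.Variable(tf.random_uniform([64, 8], -0.04, 0.04), name="Embeddings")
    with tf.variable_scope("dense"):
        tf.get_variable("kernel", shape=[8, 200])
        tf.get_variable("bias", shape=[200], initializer=tf.zeros_initializer())

    # Map checkpoint name -> shape, then keep graph variables matching both.
    ckpt_shapes = dict(tf.train.list_variables(ckpt))
    var_to_restore = [v for v in tf.global_variables()
                      if ckpt_shapes.get(v.op.name) == v.shape.as_list()]

    init = tf.global_variables_initializer()
    saver = tf.train.Saver(var_to_restore)
    with tf.Session() as ses:
        ses.run(init)             # initialize everything first
        saver.restore(ses, ckpt)  # then load only the matching variables
        restored = ses.run(emb)

print([v.op.name for v in var_to_restore])  # ['Embeddings']
print(np.allclose(saved, restored))         # True
```

The shape check is a design choice: if a variable has the same name but a different shape (for example, a different `kmer_len`), restoring it would fail, so it is simply left at its initialized value.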