Reputation: 21
I'm having an issue restoring some variables. I have restored variables before, after saving the whole model, but this time I decided to save and restore only a few of them. Before the first session, I define the weights:
weights = {
    '1': tf.Variable(tf.random_normal([n_input, n_hidden_1], mean=0, stddev=tf.sqrt(2*1.67/(n_input+n_hidden_1))), name='w1')
}
weights_saver = tf.train.Saver(var_list=weights)
Then, in a session, while I train the NN:
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    [...]
    weights_saver.save(sess, './savedModels/Weights/weights')
Then, to restore them:
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph(pathsToVariables + 'Weights/weights.meta')
    new_saver.restore(sess, pathsToVariables + 'Weights/weights')
    weights = {
        '1': tf.Variable(sess.graph.get_tensor_by_name("w1:0"), name='w1', trainable=False)
    }
    sess.run(tf.global_variables_initializer())
    print(sess.run(weights['1']))
But at this point the restored weights seem to be random. Indeed, if I run sess.run(tf.global_variables_initializer()) again, the weights change. It is as if I restored the random initialization of the weights but not the trained values.
What am I doing wrong?
Is my issue clear?
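The behavior can be reproduced with a small, self-contained sketch. It assumes the tf.compat.v1 graph-mode API (so it also runs under TensorFlow 2); the 4x3 shape, the temporary checkpoint path and the names trained, after_restore and after_init are hypothetical stand-ins for the real model. It saves w1 through a dict var_list as above, restores it with import_meta_graph, and then runs tf.global_variables_initializer():

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + Session API, as in TF 1.x

# Hypothetical small sizes; the real n_input / n_hidden_1 are not shown.
n_input, n_hidden_1 = 4, 3
ckpt = os.path.join(tempfile.mkdtemp(), "weights")  # hypothetical path

# Save step: one variable named "w1", saved through a dict var_list.
with tf.Graph().as_default():
    weights = {"1": tf.Variable(tf.random_normal([n_input, n_hidden_1]), name="w1")}
    weights_saver = tf.train.Saver(var_list=weights)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        trained = sess.run(weights["1"])  # stands in for the trained values
        weights_saver.save(sess, ckpt)

# Restore step, mirroring the question.
with tf.Graph().as_default():
    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph(ckpt + ".meta")
        new_saver.restore(sess, ckpt)
        w1 = sess.graph.get_tensor_by_name("w1:0")
        after_restore = sess.run(w1)
        # Re-running every initializer, including w1's random_normal one,
        # overwrites the values that restore() just loaded.
        sess.run(tf.global_variables_initializer())
        after_init = sess.run(w1)

print(np.allclose(after_restore, trained))  # expected: True
print(np.allclose(after_init, trained))     # expected: False (random values)
```

Under these assumptions, restore() does load the saved values; it is the later call to the initializer that replaces them.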
Upvotes: 0
Views: 552
Reputation: 21
weights = {
    '1': tf.Variable(sess.run(sess.graph.get_tensor_by_name("w1:0")), name='w1', trainable=False)
}
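I found the answer: I needed to run the tensor to get its values. sess.graph.get_tensor_by_name only returns a symbolic tensor, so the new tf.Variable does not hold a copy of the restored values, and tf.global_variables_initializer() also re-runs w1's random initializer. sess.run evaluates the tensor and returns a NumPy array, which then acts as a constant initial value. It seems obvious now.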
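A minimal end-to-end sketch of this fix, under the same assumptions (tf.compat.v1 graph mode, a hypothetical 4x3 w1 saved to a temporary path): the restored tensor is evaluated with sess.run, and the resulting NumPy array becomes the initial value of the new variable, so re-running tf.global_variables_initializer() no longer changes it.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt = os.path.join(tempfile.mkdtemp(), "weights")  # hypothetical path

# Save a small stand-in for the trained "w1" (hypothetical 4x3 shape).
with tf.Graph().as_default():
    w = tf.Variable(tf.random_normal([4, 3]), name="w1")
    saver = tf.train.Saver(var_list={"1": w})
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        trained = sess.run(w)
        saver.save(sess, ckpt)

with tf.Graph().as_default():
    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph(ckpt + ".meta")
        new_saver.restore(sess, ckpt)
        # sess.run evaluates the tensor: w1_value is a plain NumPy array.
        w1_value = sess.run(sess.graph.get_tensor_by_name("w1:0"))
        weights = {"1": tf.Variable(w1_value, name="w1", trainable=False)}
        # The new variable's initial value is a constant, so re-initializing
        # does not change it, even though the imported w1 is re-randomized.
        sess.run(tf.global_variables_initializer())
        restored = sess.run(weights["1"])

print(np.allclose(restored, trained))  # expected: True
```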
Edit 2:
This is not a good way to initialize tensors from other values. If you restore the model and then create the variable with the same name, the graph ends up with two variables requested as w1 (TensorFlow makes the second name unique, e.g. w1_1). If you use a different name, the variable restored from the old model still stays in the graph, and an optimizer may try to update it later on. It is better to restore the variable in a first session, copy its values, close that session, reset the graph, and then open a new session to create everything else.
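with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph(pathsToVariables + 'Weights/weights.meta')
    new_saver.restore(sess, pathsToVariables + 'Weights/weights')
    weight1 = sess.run(sess.graph.get_tensor_by_name("w1:0"))  # NumPy array
tf.reset_default_graph()  # this eliminates the variables we restored
with tf.Session() as sess:
    weights = {
        '1': tf.Variable(weight1, name='w1-bis', trainable=False)
    }
    ...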
Now we can be sure the restored variables are not part of the new graph.
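The two-session approach can be sketched end to end as follows, again assuming the tf.compat.v1 graph-mode API, a hypothetical 4x3 w1 and a temporary checkpoint path. After the second tf.reset_default_graph(), tf.global_variables() should list only the new variable.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt = os.path.join(tempfile.mkdtemp(), "weights")  # hypothetical path

# Save a stand-in "w1" (hypothetical 4x3 shape) as in the question.
with tf.Graph().as_default():
    w = tf.Variable(tf.random_normal([4, 3]), name="w1")
    saver = tf.train.Saver(var_list={"1": w})
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        trained = sess.run(w)
        saver.save(sess, ckpt)

# Step 1: restore into the default graph and copy the values out as NumPy.
tf.reset_default_graph()
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph(ckpt + ".meta")
    new_saver.restore(sess, ckpt)
    weight1 = sess.run(sess.graph.get_tensor_by_name("w1:0"))

# Step 2: drop every restored op, then build the new model from the array.
tf.reset_default_graph()
with tf.Session() as sess:
    weights = {"1": tf.Variable(weight1, name="w1-bis", trainable=False)}
    sess.run(tf.global_variables_initializer())
    restored = sess.run(weights["1"])
    # Only the new variable exists; the imported "w1" is gone.
    var_names = [v.op.name for v in tf.global_variables()]

print(np.allclose(restored, trained))  # expected: True
print(var_names)                       # expected: ['w1-bis']
```

The design choice is that nothing from the old checkpoint survives into the second graph except the NumPy array, so no optimizer can pick up the old variable by accident.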
Upvotes: 1