DanielTheRocketMan

Saving and restoring a model using TensorFlow

I saved the parameters of my neural network using this code:

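import tensorflow as tf

# w_h1, b_h1, ..., b_o are the tf.Variable objects created when building the network.
# The dictionary keys are the names under which the values are stored in the checkpoint.
parameters = {
    'w_h1': w_h1,
    'b_h1': b_h1,
    'w_h2': w_h2,
    'b_h2': b_h2,
    'w_h3': w_h3,
    'b_h3': b_h3,
    'w_o': w_o,
    'b_o': b_o
}

saver = tf.train.Saver(parameters)

# Runs inside an open tf.Session (sess); writes checkpoint files named my-model-<epoch>.
saver.save(sess, 'my-model', global_step=epoch)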

Now I have these three files on my disk:

checkpoint

my-model-114000

my-model-114000.meta
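For comparison, here is a minimal sketch that passes name= when each variable is created, so the graph name and the checkpoint key agree. It uses tensorflow.compat.v1 so that it runs on TensorFlow 1.15 or 2.x; the layer sizes, the temporary directory and the reduced set of two parameters are illustrative assumptions, not taken from the question.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs, sessions and ref variables

ckpt_dir = tempfile.mkdtemp()  # hypothetical output location
n_in, n_hidden = 4, 50         # hypothetical layer sizes

with tf.Graph().as_default():
    # name= sets the graph name, so the op is "w_h1" rather than "Variable_N".
    w_h1 = tf.Variable(tf.random_normal([n_in, n_hidden]), name='w_h1')
    b_h1 = tf.Variable(tf.zeros([n_hidden]), name='b_h1')

    # The dictionary keys are the names stored in the checkpoint.
    saver = tf.train.Saver({'w_h1': w_h1, 'b_h1': b_h1})

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved_path = saver.save(sess, os.path.join(ckpt_dir, 'my-model'),
                                global_step=114000)

print(w_h1.name)  # w_h1:0
print(sorted(os.listdir(ckpt_dir)))
```

With the default (V2) checkpoint format, the directory should contain checkpoint, my-model-114000.index, my-model-114000.meta and one or more my-model-114000.data-* shards; the single my-model-114000 data file listed above suggests an older TensorFlow that wrote the V1 format.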

I tried something like this:

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my-model-114000.meta')
    new_saver.restore(sess, 'my-model-114000')
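A self-contained sketch of the same restore flow, assuming the variables were created with name= (hypothetical small model, temporary path). import_meta_graph runs inside a fresh tf.Graph, as it would in a separate script, and get_tensor_by_name('w_h1:0') then finds the variable because an op with that name exists in the imported graph.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), 'my-model')  # hypothetical path

# Save: a tiny graph whose variables have explicit names.
with tf.Graph().as_default():
    w_h1 = tf.Variable(tf.random_normal([4, 50]), name='w_h1')
    b_h1 = tf.Variable(tf.zeros([50]), name='b_h1')
    saver = tf.train.Saver({'w_h1': w_h1, 'b_h1': b_h1})
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved_w_h1 = sess.run(w_h1)
        saved_path = saver.save(sess, prefix, global_step=114000)

# Restore in a fresh graph, as a separate script would.
with tf.Graph().as_default() as graph:
    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph(saved_path + '.meta')
        new_saver.restore(sess, saved_path)
        # Found because the variable op was created with name='w_h1'.
        restored_w_h1 = sess.run(graph.get_tensor_by_name('w_h1:0'))

print(np.allclose(restored_w_h1, saved_w_h1))  # True
```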

I received the message:

INFO:tensorflow:Restoring parameters from my-model-114000

However, I am not able to get the original parameters back. I tried something like this (inside the with tf.Session() as sess: block):

w_h1 = tf.get_default_graph().get_tensor_by_name("w_h1:0")

but I get this error:

KeyError: "The name 'w_h1:0' refers to a Tensor which does not exist. The operation, 'w_h1', does not exist in the graph."

How can I recover the weights?
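Judging from the Variable_N names printed below, the variables were probably created without name=, so their graph names are Variable, Variable_1 and so on, and no op called w_h1 exists, which would explain the KeyError. The dictionary keys passed to tf.train.Saver are still the keys inside the checkpoint, so one option is to read the values by key with tf.train.load_variable, without rebuilding any graph. The sketch below reproduces that situation with a hypothetical [4, 50] weight and a temporary path.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), 'my-model')  # hypothetical path

# Reproduce the question: no name= argument, so the graph name is "Variable",
# while the Saver dictionary stores the value under the key "w_h1".
with tf.Graph().as_default():
    w_h1 = tf.Variable(tf.random_normal([4, 50]))
    saver = tf.train.Saver({'w_h1': w_h1})
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved_w_h1 = sess.run(w_h1)
        saved_path = saver.save(sess, prefix, global_step=114000)

print(w_h1.name)  # Variable:0

# List the (key, shape) pairs stored in the checkpoint.
print(tf.train.list_variables(saved_path))

# Read one value by its checkpoint key; the result is a NumPy array.
w_h1_value = tf.train.load_variable(saved_path, 'w_h1')
print(np.allclose(w_h1_value, saved_w_h1))  # True
```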

I used

    for var in tf.all_variables():  # deprecated alias of tf.global_variables()
        print(var)

to see what had been saved, and I realized that it had saved a lot of variables (a sample is below), although I thought I had saved only a small number of important parameters:

<tf.Variable 'Variable_21/Adam_3:0' shape=(50,) dtype=float32_ref>
<tf.Variable 'Variable_24/Adam_2:0' shape=(50, 50) dtype=float32_ref>
<tf.Variable 'Variable_24/Adam_3:0' shape=(50, 50) dtype=float32_ref>
<tf.Variable 'Variable_25/Adam_2:0' shape=(50,) dtype=float32_ref>
<tf.Variable 'Variable_25/Adam_3:0' shape=(50,) dtype=float32_ref>
<tf.Variable 'Variable_28/Adam_2:0' shape=(50, 1) dtype=float32_ref>
<tf.Variable 'Variable_28/Adam_3:0' shape=(50, 1) dtype=float32_ref>
<tf.Variable 'Variable_29/Adam_2:0' shape=(1,) dtype=float32_ref>
<tf.Variable 'Variable_29/Adam_3:0' shape=(1,) dtype=float32_ref>
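tf.all_variables() (now tf.global_variables()) lists the variables defined in whichever graph it runs in, and a graph rebuilt from the .meta file contains everything the optimizer created as well, such as Adam's slot variables. The checkpoint data itself holds only the tensors in the Saver dictionary. A sketch comparing the two lists, assuming a hypothetical one-layer model trained with tf.train.AdamOptimizer and a temporary path:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), 'my-model')  # hypothetical path

# Hypothetical one-layer model trained with Adam.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 4])
    w_o = tf.Variable(tf.random_normal([4, 1]), name='w_o')
    b_o = tf.Variable(tf.zeros([1]), name='b_o')
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w_o) + b_o))
    tf.train.AdamOptimizer().minimize(loss)  # creates slot variables such as w_o/Adam
    saver = tf.train.Saver({'w_o': w_o, 'b_o': b_o})  # saves only these two tensors
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved_path = saver.save(sess, prefix, global_step=114000)

# Variables defined in the graph stored in the .meta file.
with tf.Graph().as_default():
    tf.train.import_meta_graph(saved_path + '.meta')
    graph_vars = [v.name for v in tf.global_variables()]

# Tensors actually stored in the checkpoint data.
ckpt_keys = [name for name, _ in tf.train.list_variables(saved_path)]

print(graph_vars)  # model weights plus optimizer variables such as w_o/Adam:0
print(ckpt_keys)   # ['b_o', 'w_o']
```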

Answers (1)

Jie.Zhou

Names like 'Variable_21/Adam_3:0' are your variable names; "w_h1" is not. You should get this tensor with

w_h1 = tf.get_default_graph().get_tensor_by_name("Variable_21/Adam_3:0")
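Note that a name with an /Adam suffix, such as 'Variable_21/Adam_3:0', belongs to one of the slot variables (moment estimates) that tf.train.AdamOptimizer creates for the variable 'Variable_21'; the weight itself is the tensor 'Variable_21:0'. A hedged sketch, assuming unnamed variables as in the question and a hypothetical model and path: after import_meta_graph, tf.trainable_variables() lists only the model weights (optimizer slots are not trainable), which helps match the default graph names to the w_h1, b_h1, ... parameters by creation order and shape.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), 'my-model')  # hypothetical path

# Unnamed variables, as in the question, trained with Adam.
with tf.Graph().as_default():
    w_h1 = tf.Variable(tf.random_normal([4, 50]))  # graph name "Variable"
    b_h1 = tf.Variable(tf.zeros([50]))             # graph name "Variable_1"
    loss = tf.reduce_sum(tf.square(w_h1)) + tf.reduce_sum(tf.square(b_h1))
    tf.train.AdamOptimizer().minimize(loss)  # adds slots such as "Variable/Adam"
    saver = tf.train.Saver({'w_h1': w_h1, 'b_h1': b_h1})
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved_path = saver.save(sess, prefix, global_step=114000)

with tf.Graph().as_default() as graph:
    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph(saved_path + '.meta')
        new_saver.restore(sess, saved_path)
        # Model weights only; optimizer slots are not in this collection.
        trainable_names = [v.name for v in tf.trainable_variables()]
        for v in tf.trainable_variables():
            print(v.name, v.shape)
        # Fetch the weight itself (not an Adam slot) by its default graph name.
        restored_w_h1 = sess.run(graph.get_tensor_by_name('Variable:0'))

print(np.allclose(restored_w_h1, tf.train.load_variable(saved_path, 'w_h1')))  # True
```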
