Reputation: 894
After training a model in TensorFlow, I save it as follows:
saver = tf.train.Saver()
saver.save(sess,'myModel/Path/Model_1')
Generating files called:
To restore the model, I create a new session, build the TensorFlow graph in exactly the same way as it was originally created, and then restore it as follows:
sess = tf.Session()
# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()
sess.run(init)
imported_meta = tf.train.Saver()
imported_meta.restore(sess,'myModel/Path/Model_1.meta')
Which throws the following error:
InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [6152,32] rhs shape= [6164,80]
[[Node: save_2/Assign_3 = Assign[T=DT_FLOAT, _class=["loc:@DGNS/bidirectional_rnn/bw/basic_lstm_cell/kernel"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](DGNS/bidirectional_rnn/bw/basic_lstm_cell/kernel, save_2/RestoreV2/_111)]]
Caused by op u'save_2/Assign_3', defined at:
File "/usr/lib/python2.7/dist-packages/spyderlib/widgets/externalshell/start_ipython_kernel.py", line 205, in <module>
__ipythonkernel__.start()
File "/usr/lib/python2.7/dist-packages/IPython/kernel/zmq/kernelapp.py", line 459, in start
ioloop.IOLoop.instance().start()
File "/usr/lib/python2.7/dist-packages/zmq/eventloop/ioloop.py", line 162, in start
super(ZMQIOLoop, self).start()
File "/usr/lib/python2.7/dist-packages/zmq/eventloop/minitornado/ioloop.py", line 830, in start
self._run_callback(callback)
File "/usr/lib/python2.7/dist-packages/zmq/eventloop/minitornado/ioloop.py", line 603, in _run_callback
ret = callback()
... ... etc
I need help understanding what is happening here. The error hints at a shape mismatch, but I do not understand how that can happen, since I used exactly the same code to generate the model and to initialize the new graph. The only difference in the code is the model loading part.
How can I start debugging this error to find out how to load my model correctly?
Upvotes: 0
Views: 602
Reputation: 1204
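I'm pretty sure you are not supposed to pass the .meta file to restore(). It's tricky to understand, since saving a checkpoint outputs 3 different files. Load the graph from the .meta file with import_meta_graph, then restore the variables using the same path prefix you passed to save(). Try this: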
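with tf.Session() as sess:
    # Rebuild the graph structure from the .meta file
    new_saver = tf.train.import_meta_graph(
        'myModel/Path/Model_1.meta', clear_devices=True)
    # Restore the variable values from the checkpoint prefix (no file extension)
    new_saver.restore(sess, 'myModel/Path/Model_1')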
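If the shape error persists, one way to start debugging is to compare the shapes stored in the checkpoint with the shapes of the variables in the current graph. The following is only a rough sketch, assuming TensorFlow 1.x: find_shape_mismatches and load_shapes are hypothetical helper names, not part of TensorFlow, and the example values are taken from the error message in the question.

```python
def find_shape_mismatches(checkpoint_shapes, graph_shapes):
    """Return {name: (graph_shape, checkpoint_shape)} for variables whose shapes differ.

    Both arguments are plain dicts mapping variable names to shape lists.
    Variables missing from the checkpoint are skipped.
    """
    mismatches = {}
    for name, graph_shape in graph_shapes.items():
        ckpt_shape = checkpoint_shapes.get(name)
        if ckpt_shape is not None and list(ckpt_shape) != list(graph_shape):
            mismatches[name] = (list(graph_shape), list(ckpt_shape))
    return mismatches


def load_shapes(checkpoint_prefix):
    """Hypothetical helper: collect shapes from a checkpoint and the default graph.

    Assumes TensorFlow 1.x; call it after the graph has been built.
    """
    import tensorflow as tf  # imported here so the comparison above works without TF

    # list_variables yields (name, shape) pairs for everything in the checkpoint
    checkpoint_shapes = dict(tf.train.list_variables(checkpoint_prefix))
    graph_shapes = {v.op.name: v.shape.as_list() for v in tf.global_variables()}
    return checkpoint_shapes, graph_shapes


# Example with the shapes reported in the error message (lhs = graph, rhs = checkpoint)
name = 'DGNS/bidirectional_rnn/bw/basic_lstm_cell/kernel'
example = find_shape_mismatches({name: [6164, 80]}, {name: [6152, 32]})
for var_name, (graph_shape, ckpt_shape) in example.items():
    print('%s: graph %s vs checkpoint %s' % (var_name, graph_shape, ckpt_shape))
```

With a real checkpoint, something like find_shape_mismatches(*load_shapes('myModel/Path/Model_1')) would list every variable whose shape in the graph differs from the saved one, which usually points to the hyperparameter or input dimension that changed between saving and loading.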
Also, just for clarification, are you storing your full model in a .pb file as well, or just generating these checkpoints?
Upvotes: 1