Necromancer

Reputation: 929

Importing tf.train.Saver from another python file

I use tf.train.Saver() in a file called one.py with the following code:

saver = tf.train.Saver(tf.all_variables())
saver.save(sess,"checkpoint.data")

How can I restore checkpoint.data in another python file?

I used the following code, but it didn't work.

from one import saver
import tensorflow as tf

with tf.Session() as sess:
    saver.restore(sess, "checkpoint.data")

Upvotes: 0

Views: 282

Answers (1)

mrry

Reputation: 126194

The checkpoint file (i.e. 'checkpoint.data') does not provide TensorFlow with enough information to reconstruct your model structure. In your second program, you need to reconstruct the same TensorFlow graph that was used in the first program. There are a few options for doing this:

  • Extract the model building code into a Python function, and call it before creating the tf.train.Saver in each program.
  • Use saver.export_meta_graph() to write out the graph structure along with a checkpoint in your first program, and tf.train.import_meta_graph() to import the graph structure (and create an appropriately configured tf.train.Saver instance) in your second program.
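A minimal sketch of the first option follows. It assumes TF1-style graphs and sessions accessed through `tensorflow.compat.v1`. The function `build_model()` and its variables `w` and `b` are hypothetical stand-ins for your real model code. To keep the example self-contained, both "programs" run in one file, each in its own fresh `tf.Graph`; in practice `build_model()` would live in a shared module that both scripts import. Note that `tf.global_variables()` is the newer name for the deprecated `tf.all_variables()`.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs and sessions


def build_model():
    """Hypothetical model-building code shared by both programs."""
    w = tf.get_variable("w", shape=[2, 2], initializer=tf.ones_initializer())
    b = tf.get_variable("b", shape=[2], initializer=tf.zeros_initializer())
    return w, b


ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# "First program": build the graph, then create the Saver and save.
with tf.Graph().as_default():
    w, b = build_model()
    saver = tf.train.Saver(tf.global_variables())
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(w.assign(w * 3.0))  # change a value so the restore is observable
        save_path = saver.save(sess, ckpt_path)

# "Second program": rebuild the same graph, then create a new Saver and restore.
with tf.Graph().as_default():
    w, b = build_model()
    saver = tf.train.Saver(tf.global_variables())
    with tf.Session() as sess:
        saver.restore(sess, save_path)  # restore sets every variable; no initializer needed
        restored_w = sess.run(w)

print(restored_w)  # every entry is 3.0
```

The key point is that the Saver in the second program is created *after* the same variables exist in its graph, so it knows which tensors to fill from the checkpoint.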

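The second option can be sketched as follows, again assuming `tensorflow.compat.v1`. The tensor names `x`, `w` and `y` are illustrative only. The first program saves the variable values and exports the graph structure to a `.meta` file; the second program never rebuilds the model in Python, but calls `tf.train.import_meta_graph()`, which recreates the ops and returns a Saver configured for them. Tensors are then looked up by name.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs and sessions

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# First program: build the graph, save the variables, export the graph structure.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf.Variable([[1.0], [2.0]], name="w")
    y = tf.matmul(x, w, name="y")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_prefix, write_meta_graph=False)  # variable values only
        saver.export_meta_graph(ckpt_prefix + ".meta")  # graph structure + saver config

# Second program: import the graph structure, then restore the variable values.
with tf.Graph().as_default() as graph:
    saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
    with tf.Session() as sess:
        saver.restore(sess, ckpt_prefix)
        x = graph.get_tensor_by_name("x:0")
        y = graph.get_tensor_by_name("y:0")
        result = sess.run(y, feed_dict={x: [[1.0, 1.0]]})

print(result)  # [[3.]]
```

This approach is convenient when the second program should not depend on the model-building code, at the cost of referring to tensors by their string names.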
Upvotes: 1
