Omar Hommos

Reputation: 163

Restoring a Tensorflow model that uses Iterators

I have a model that trains my network using an Iterator, following the new Dataset API pipeline that Google now recommends.

I read tfrecord files, feed data to the network, and training goes well. At the end of training I save my model so I can run inference on it later. A simplified version of the code is as follows:

""" Training and saving """

training_dataset = tf.contrib.data.TFRecordDataset(training_record)
training_dataset = training_dataset.map(ds._path_records_parser)
training_dataset = training_dataset.batch(BATCH_SIZE)
with tf.name_scope("iterators"):
  training_iterator = Iterator.from_structure(training_dataset.output_types, training_dataset.output_shapes)
  next_training_element = training_iterator.get_next()
  training_init_op = training_iterator.make_initializer(training_dataset)

def train(num_epochs):
  # compute for the number of epochs
  for e in range(1, num_epochs+1):
    session.run(training_init_op) #initializing iterator here
    while True:
      try:
        images, labels = session.run(next_training_element)
        session.run(optimizer, feed_dict={x: images, y_true: labels})
      except tf.errors.OutOfRangeError:
        saver_name = './saved_models/ucf-model'
        print("Finished Training Epoch {}".format(e))
        break



""" Restoring """
# restoring the saved model and its variables
session = tf.Session()
saver = tf.train.import_meta_graph(r'saved_models\ucf-model.meta')
saver.restore(session, tf.train.latest_checkpoint('.\saved_models'))
graph = tf.get_default_graph()

# restoring relevant tensors/ops
accuracy = graph.get_tensor_by_name("accuracy/Mean:0") #the tensor that when evaluated returns the mean accuracy of the batch
testing_iterator = graph.get_operation_by_name("iterators/Iterator") #my iterator used in testing.
next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext") #the GetNext operator for my iterator
# loading my testing set tfrecords
testing_dataset = tf.contrib.data.TFRecordDataset(testing_record_path)
testing_dataset = testing_dataset.map(ds._path_records_parser, num_threads=4, output_buffer_size=BATCH_SIZE*20)
testing_dataset = testing_dataset.batch(BATCH_SIZE)

testing_init_op = testing_iterator.make_initializer(testing_dataset) #to initialize the dataset

with tf.Session() as session:
  session.run(testing_init_op)
  while True:
    try:
      images, labels = session.run(next_testing_element)
      accuracy = session.run(accuracy, feed_dict={x: test_images, y_true: test_labels}) #error here, x, y_true not defined
    except tf.errors.OutOfRangeError:
      break

My problem arises mainly when I restore the model: how do I feed testing data to the network?

Another issue: since an iterator is being used, there's no need for placeholders in the training model, because the iterator feeds data directly to the graph. But in that case, how do I restore my feed_dict keys in the third-to-last line, where I feed data to the "accuracy" op?

EDIT: if someone could suggest a way to add placeholders between the Iterator and the network input, then I could try running the graph by evaluating the "accuracy" tensor while feeding data to the placeholders and ignoring the iterator altogether.

Upvotes: 11

Views: 5592

Answers (4)

Aga

Reputation: 179

I would suggest having a look at CheckpointInputPipelineHook, which implements saving iterator state for further training with tf.Estimator.

Upvotes: 0

P-Gn

Reputation: 24651

I would suggest using tf.contrib.data.make_saveable_from_iterator, which has been designed precisely for this purpose. It is much less verbose and does not require you to change existing code, in particular how you define your iterator.

Here is a working example in which we save everything after step 5 has completed. Note how I don't even bother knowing what seed is used.

import tensorflow as tf

iterator = (
  tf.data.Dataset.range(100)
  .shuffle(10)
  .make_one_shot_iterator())
batch = iterator.get_next(name='batch')

saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
saver = tf.train.Saver()

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  for step in range(10):
    print('{}: {}'.format(step, sess.run(batch)))
    if step == 5:
      saver.save(sess, './foo', global_step=step)

# 0: 1
# 1: 6
# 2: 7
# 3: 3
# 4: 8
# 5: 10
# 6: 12
# 7: 14
# 8: 5
# 9: 17

Then later, if we resume from step 6, we get the same output.

import tensorflow as tf

saver = tf.train.import_meta_graph('./foo-5.meta')
with tf.Session() as sess:
  saver.restore(sess, './foo-5')
  for step in range(6, 10):
    print('{}: {}'.format(step, sess.run('batch:0')))
# 6: 12
# 7: 14
# 8: 5
# 9: 17

Upvotes: 4

metastableB

Reputation: 812

When restoring a saved meta graph, you can retrieve the initialization operation by name and then run it again to initialize the input pipeline for inference.

That is, when creating the graph, you can do

    dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')

And then restore this operation by doing:

    dataset_init_op = graph.get_operation_by_name('dataset_init')

Here is a self-contained code snippet that compares results of a randomly initialized model before and after restoring.

Saving an Iterator

import numpy as np
import tensorflow as tf

np.random.seed(42)
data = np.random.random([4, 4])
X = tf.placeholder(dtype=tf.float32, shape=[4, 4], name='X')
dataset = tf.data.Dataset.from_tensor_slices(X)
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
dataset_next_op = iterator.get_next()

# name the operation
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')

w = np.random.random([1, 4])
W = tf.Variable(w, name='W', dtype=tf.float32)
output = tf.multiply(W, dataset_next_op, name='output')     
sess = tf.Session()
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
sess.run(dataset_init_op, feed_dict={X:data})
while True:
    try:
        print(sess.run(output))
    except tf.errors.OutOfRangeError:
        # the dataset is exhausted: save the model, then leave the loop
        saver.save(sess, 'tmp/', global_step=1002)
        break

And then you can restore the same model for inference as follows:

Restoring saved iterator

import os
import numpy as np
import tensorflow as tf

np.random.seed(42)
data = np.random.random([4, 4])
tf.reset_default_graph()
sess = tf.Session()
saver = tf.train.import_meta_graph('tmp/-1002.meta')
ckpt = tf.train.get_checkpoint_state(os.path.dirname('tmp/checkpoint'))
saver.restore(sess, ckpt.model_checkpoint_path)
graph = tf.get_default_graph()

# Restore the init operation
dataset_init_op = graph.get_operation_by_name('dataset_init')

X = graph.get_tensor_by_name('X:0')
output = graph.get_tensor_by_name('output:0')
sess.run(dataset_init_op, feed_dict={X:data})
while True:
    try:
        print(sess.run(output))
    except tf.errors.OutOfRangeError:
        break

Upvotes: 12

Omar Hommos

Reputation: 163

I couldn't solve the problem of initializing the iterator. However, I pre-process my dataset with the map method, applying transformations defined as Python operations wrapped with py_func, which cannot be serialized for storing/restoring. So I'll have to initialize my dataset when I restore the model anyway.

So, the problem that remains is how to feed data to my graph when I restore it. I placed a tf.identity node between the iterator output and my network input. Upon restoring, I feed my data to the identity node. A better solution that I discovered later is using placeholder_with_default(), as described in this answer.
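A minimal sketch of the placeholder_with_default() idea, for illustration only. It assumes the TF1-style graph API (imported here through tensorflow.compat.v1 so it can also run on TF 2.x), and it uses a toy in-memory dataset and a made-up "network" (a doubled row sum) in place of the real tfrecord pipeline and model. The tensor names "x" and "output" are hypothetical.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF1-style graph API via compat.v1

graph = tf.Graph()
with graph.as_default():
    # Toy pipeline standing in for the TFRecord dataset: 4 rows, batches of 2.
    data = np.arange(8, dtype=np.float32).reshape(4, 2)
    dataset = tf.data.Dataset.from_tensor_slices(data).batch(2)
    next_batch = tf.data.make_one_shot_iterator(dataset).get_next()

    # The network reads `x`. If nothing is fed, `x` takes the iterator output.
    x = tf.placeholder_with_default(next_batch, shape=[None, 2], name="x")
    # Hypothetical stand-in for the real network.
    output = tf.reduce_sum(x * 2.0, axis=1, name="output")

with tf.Session(graph=graph) as sess:
    # Look the tensors up by name, as one would after import_meta_graph.
    x_in = graph.get_tensor_by_name("x:0")
    out = graph.get_tensor_by_name("output:0")

    # Training-style run: the data comes from the iterator.
    from_iterator = sess.run(out)
    # Inference-style run: the fed value overrides the iterator output.
    from_feed = sess.run(out, feed_dict={x_in: [[1.0, 2.0]]})
    print(from_iterator, from_feed)
```

When a value is fed for "x:0", the iterator is not evaluated for that run, so inference does not depend on the training dataset being initialized. A tf.identity node can be fed in the same way, since feed_dict accepts any feedable tensor, not only placeholders.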

Upvotes: 1
