Talto

Reputation: 85

How to restore a Tensorflow Model trained using Dataset API?

I'm training my model using the Dataset API with a feedable iterator, as in the Importing Data tutorial. The problem comes when restoring the model: it also restores the handle placeholder, together with the output types and shapes it had during training. That means the restored graph expects to receive both an example and a label.

    def loadTFRecord(filenames):          
      dataset = tf.data.TFRecordDataset([filenames])
      dataset = dataset.map(extract_img_func)
      dataset = dataset.batch(batchsize)
      handle = tf.placeholder(tf.string, shape=[])
      iterator = tf.data.Iterator.from_string_handle(handle, dataset.output_types, dataset.output_shapes)
      training_iterator = dataset.make_one_shot_iterator()
      next_element = iterator.get_next() 
      training_handle = self.sess.run(training_iterator.string_handle())
      return next_element #next_element[0] is the example img, next_element[1] is the label

    def model_fn(images, labels=None, train=False):
      input_layer = images
      ...
      predictions = last_layer
      if train:
        return predictions

      # Calculate loss
      loss = tf.losses.mean_squared_error(labels, predictions)
      learning_rate = tf.train.exponential_decay(learning_rate=learningRate, staircase=True)
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
      train_op = optimizer.minimize(
          loss=loss,
          global_step=global_step)

      return train_op, predictions, loss

With this I am creating my model for training:

    examples, labels = loadTFRecord("path/to/tfrecord")
    model_fn(examples, labels=labels)
    saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=0.5)
    ... #training here
    saver.save(sess, "path/to/")

The problem arises when I want to restore the model for inference. I want to restore the model and pass in another feedable iterator that loads some .png files from disk. I set this up much like the TFRecord loading:

    def load_images(filenames):
      dataset = tf.data.Dataset.from_tensor_slices(filenames)
      dataset = dataset.map(lambda x: tf.image.resize_images(self.normalize(tf.image.decode_png(tf.read_file(x), channels = 3)), [IM_WIDTH, IM_HEIGHT]))
      dataset = dataset.batch(1)
      iterator = tf.data.Iterator.from_string_handle(handle, dataset.output_types, dataset.output_shapes)
      iterator = dataset.make_one_shot_iterator()
      next_img = iterator.get_next()
      training_handle = sess.run(iterator.string_handle())
      return next_img

The error occurs when I pass it to the restored model like this:

    saver = tf.train.import_meta_graph(modelbasepath + ".meta")
    saver.restore(sess, modelbasepath)
    ... # restore operations here
    # finally run predictions, error occurs here!
    predictions = sess.run([predictions], feed_dict={handle: training_handle})

I'm getting this error:

    Number of components does not match: expected 2 types but got 1.
     [[Node: IteratorFromStringHandle_2 = IteratorFromStringHandle[output_shapes=[[?,80,80,3], [?,80,80,?]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_Placeholder_1_0_0)]]

This tells me that it is expecting to also get a label, while I am only feeding an image to predict.

How can I overcome this? Is there a way to change the shape of the placeholder, or how would one implement this so that it is possible to restore a model trained with the Dataset API and feed dicts?

Upvotes: 1

Views: 138

Answers (1)

DMolony

Reputation: 643

I ran into the same problem and could not come up with a clean solution. I ended up creating a dummy tensor for the label and returning it alongside the image when loading images. There may be a better way of doing this, but this workaround should allow you to run the model.

    def load_images(x):
        # Decode, normalize and resize the image as in the question
        image = tf.image.decode_png(tf.read_file(x), channels=3)
        image = self.normalize(image)
        image = tf.image.resize_images(image, [IM_WIDTH, IM_HEIGHT])

        # Assuming the label has one channel, slice the image to get matching dims
        label = tf.zeros_like(image[:, :, 0:1])

        return image, label

    # Map after load_images is defined so the name exists when it is used
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.map(load_images)
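
To show how the dummy label fits together with a feedable iterator, here is a minimal, self-contained sketch. It does not restore a checkpoint. Instead it builds a feedable iterator with the same two-component structure reported in the error message and feeds it the handle of an image-plus-dummy-label dataset. Several things here are assumptions for illustration only: the `with_dummy_label` helper name, the random arrays standing in for decoded .png files, the 80x80 image size (read off the error message), and the `tensorflow.compat.v1` import, which is used so that the TF 1.x-style calls also run on newer installations.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use graph mode, as in the TF 1.x code above

IM_HEIGHT, IM_WIDTH = 80, 80  # assumed from the shapes in the error message


def with_dummy_label(image):
    # Hypothetical helper: pair each image with an all-zero, single-channel
    # label so the dataset has two components, like the training dataset.
    label = tf.zeros_like(image[:, :, 0:1])
    return image, label


graph = tf.Graph()
with graph.as_default():
    # Stand-in for the decoded .png files: four random float32 images.
    images = np.random.rand(4, IM_HEIGHT, IM_WIDTH, 3).astype(np.float32)
    dataset = tf.data.Dataset.from_tensor_slices(images)
    dataset = dataset.map(with_dummy_label).batch(1)

    # Feedable iterator declared with the (image, label) structure that the
    # error message says the trained graph expects.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle,
        output_types=(tf.float32, tf.float32),
        output_shapes=(tf.TensorShape([None, IM_HEIGHT, IM_WIDTH, 3]),
                       tf.TensorShape([None, IM_HEIGHT, IM_WIDTH, None])))
    next_image, next_label = iterator.get_next()

    # Concrete iterator over the inference data; only its handle is fed.
    inference_iterator = dataset.make_one_shot_iterator()

    with tf.Session() as sess:
        inference_handle = sess.run(inference_iterator.string_handle())
        image_batch, label_batch = sess.run(
            [next_image, next_label], feed_dict={handle: inference_handle})
```

With a restored model, the same idea would apply to the handle placeholder recovered from the imported graph: the handle that is fed has to come from an iterator whose dataset yields two components of compatible type and shape, and the dummy label is simply ignored when only the predictions are fetched.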

Upvotes: 1
