pjao4512

Reputation: 13

TensorFlow Entire Dataset stored in Graph

I'm developing a CNN for the CIFAR-10 dataset. To feed data to the network, I'm using the Dataset API with feedable iterators and string-handle placeholders: https://www.tensorflow.org/programmers_guide/datasets#creating_an_iterator. I like this approach because it gives me a clear, simple way to feed data to the network and to switch between my testing and validation sets. However, when I save the graph at the end of training, the resulting .meta file is as large as the testing data I started with. I use these operations so that I can access the input placeholders and the output op later:

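import tensorflow as tf

# Register the input placeholders and the prediction op so they can be
# retrieved later with tf.get_collection("validation_nodes").
tf.add_to_collection("validation_nodes", input_data)
tf.add_to_collection("validation_nodes", input_labels)
tf.add_to_collection("validation_nodes", predict)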

To save the graph, I create a Saver before training:

saver = tf.train.Saver()

After training:

save_path = saver.save(sess, "./my_model")
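For reference, here is a minimal, self-contained sketch of the collection-based save/restore round trip described above. It assumes TensorFlow 1.14 or later (including 2.x) and uses the tf.compat.v1 API. The tiny model (a 4-feature input, a weights variable, an argmax predict op) and the temporary save directory are hypothetical stand-ins, not the original CNN.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

# Hypothetical save location standing in for "./my_model".
save_prefix = os.path.join(tempfile.mkdtemp(), "my_model")

graph = tf.Graph()
with graph.as_default():
    # Hypothetical tiny model: 4 input features, 3 classes, argmax prediction.
    input_data = tf1.placeholder(tf.float32, [None, 4], name="input_data")
    input_labels = tf1.placeholder(tf.int64, [None], name="input_labels")
    weights = tf1.get_variable("weights", [4, 3])
    predict = tf.argmax(tf.matmul(input_data, weights), axis=1, name="predict")

    # Register the nodes to retrieve later; collections keep insertion order.
    tf1.add_to_collection("validation_nodes", input_data)
    tf1.add_to_collection("validation_nodes", input_labels)
    tf1.add_to_collection("validation_nodes", predict)

    saver = tf1.train.Saver()  # created before training
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        # (training would happen here)
        saver.save(sess, save_prefix)  # writes my_model.meta, .index and .data files

# Later, in a fresh graph: rebuild the graph from the .meta file, then restore weights.
restored = tf.Graph()
with restored.as_default():
    loader = tf1.train.import_meta_graph(save_prefix + ".meta")
    x, y, pred = tf1.get_collection("validation_nodes")
    with tf1.Session() as sess:
        loader.restore(sess, save_prefix)
        out = sess.run(pred, feed_dict={x: np.ones((2, 4), np.float32)})

print(out.shape)  # (2,)
```

The .meta file holds the graph structure (and the collections), while the .index and .data files hold the variable values; anything embedded in the graph as a constant therefore ends up in the .meta file.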

Is there a way to prevent TensorFlow from storing all the data in the graph? Thanks in advance!

Upvotes: 1

Views: 339

Answers (1)

David Parks

Reputation: 32071

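You're creating a tf.constant for the dataset, which is why it's added to the graph definition (this happens, for example, when you pass NumPy arrays directly to Dataset.from_tensor_slices: their values are stored in the graph and therefore serialized into the .meta file). The solution is to use an initializable iterator and define placeholders for the data. Before you run any other operations against the graph, run the iterator's initializer and feed it the dataset through those placeholders. See the programmer's guide under the "Creating an iterator" section for an example.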

https://www.tensorflow.org/programmers_guide/datasets#creating_an_iterator
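To see the difference described above, the following minimal sketch builds the same dataset twice, once from NumPy arrays and once from placeholders, and compares the serialized GraphDef sizes. It assumes TensorFlow 1.14 or later (including 2.x) through tf.compat.v1, and the random arrays are hypothetical stand-ins for CIFAR-10. The .meta file contains this GraphDef, so the first case is the one that inflates it.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

# Random arrays standing in for CIFAR-10 (1000 images of 32x32x3 uint8).
rng = np.random.RandomState(0)
images = rng.randint(0, 256, size=(1000, 32, 32, 3)).astype(np.uint8)
labels = rng.randint(0, 10, size=(1000,)).astype(np.uint8)


def graph_def_bytes(use_placeholders):
    """Build a dataset plus an initializable iterator; return the GraphDef size in bytes."""
    graph = tf.Graph()
    with graph.as_default():
        if use_placeholders:
            # Placeholders record only dtype and shape; the arrays are fed when the
            # iterator's initializer runs, so they never enter the GraphDef.
            imgs_src = tf1.placeholder(tf.uint8, [None, 32, 32, 3])
            labels_src = tf1.placeholder(tf.uint8, [None])
        else:
            # NumPy arrays are converted to Const ops whose values live in the GraphDef.
            imgs_src, labels_src = images, labels
        dataset = tf.data.Dataset.from_tensor_slices((imgs_src, labels_src)).batch(128)
        tf1.data.make_initializable_iterator(dataset)
    return graph.as_graph_def().ByteSize()


const_bytes = graph_def_bytes(use_placeholders=False)
placeholder_bytes = graph_def_bytes(use_placeholders=True)
print(images.nbytes, const_bytes, placeholder_bytes)
```

The constant-backed graph comes out larger than images.nbytes (3,072,000 bytes here), while the placeholder-backed graph stays small no matter how much data is fed later.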

I do the same thing, so here is a copy/paste of the relevant parts of the code I use to do exactly what you describe (train/test sets of CIFAR-10 using initializable iterators):

  def build_datasets(self):
    """ Creates a train_iterator and test_iterator from the two datasets. """
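    # Data is fed into these placeholders when an iterator is initialized, so the
    # arrays are never embedded in the graph as tf.constant ops.
    self.imgs_4d_uint8_placeholder = tf.placeholder(tf.uint8, [None, 32, 32, 3], 'load_images_placeholder')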
    self.imgs_4d_float32_placeholder = tf.placeholder(tf.float32, [None, 32, 32, 3], 'load_images_float32_placeholder')
    self.labels_1d_uint8_placeholder = tf.placeholder(tf.uint8, [None], 'load_labels_placeholder')
    self.load_data_train = tf.data.Dataset.from_tensor_slices({
      'data': self.imgs_4d_uint8_placeholder,
      'labels': self.labels_1d_uint8_placeholder
    })
    self.load_data_test = tf.data.Dataset.from_tensor_slices({
      'data': self.imgs_4d_uint8_placeholder,
      'labels': self.labels_1d_uint8_placeholder
    })
    self.load_data_adversarial = tf.data.Dataset.from_tensor_slices({
      'data': self.imgs_4d_float32_placeholder,
      'labels': self.labels_1d_uint8_placeholder
    })

    # Train dataset pipeline
    dataset_train = self.load_data_train
    dataset_train = dataset_train.shuffle(buffer_size=50000)
    dataset_train = dataset_train.repeat()
    dataset_train = dataset_train.map(self._img_augmentation, num_parallel_calls=8)
    dataset_train = dataset_train.map(self._img_preprocessing, num_parallel_calls=8)
    dataset_train = dataset_train.batch(self.hyperparams['batch_size'])
    dataset_train = dataset_train.prefetch(2)
    self.iterator_train = dataset_train.make_initializable_iterator()

    # Test dataset pipeline
    dataset_test = self.load_data_test
    dataset_test = dataset_test.map(self._img_preprocessing, num_parallel_calls=8)
    dataset_test = dataset_test.batch(self.hyperparams['batch_size'])
    self.iterator_test = dataset_test.make_initializable_iterator()

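  def init(self, sess):
    self.cifar10 = Cifar10()    # a class I wrote for loading cifar10
    # String handles select which iterator feeds the model through the
    # feedable-iterator handle placeholder (defined elsewhere) in training/eval ops.
    self.handle_train = sess.run(self.iterator_train.string_handle())
    self.handle_test = sess.run(self.iterator_test.string_handle())
    # The initializer depends only on the data placeholders, so the handle does not
    # need to be fed here. The arrays are fed at run time and never stored in the graph.
    sess.run(self.iterator_train.initializer, feed_dict={self.imgs_4d_uint8_placeholder: self.cifar10.train_data,
                                                         self.labels_1d_uint8_placeholder: self.cifar10.train_labels})
    # The test dataset does not repeat, so iterator_test is initialized the same way
    # (feeding the test arrays) before each evaluation pass.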
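Below is a compact, runnable sketch of the same pattern with tiny in-memory arrays. It also includes the feedable-iterator piece mentioned in the question: a string-handle placeholder plus tf.compat.v1.data.Iterator.from_string_handle. The preprocess function, batch size of 4, and array contents are hypothetical stand-ins for _img_preprocessing, the hyperparameters, and the Cifar10 loader. It assumes TensorFlow 1.14 or later (including 2.x) through tf.compat.v1.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    # Placeholders for the raw arrays, shared by both pipelines (as in build_datasets).
    imgs_ph = tf1.placeholder(tf.uint8, [None, 32, 32, 3], "load_images_placeholder")
    labels_ph = tf1.placeholder(tf.uint8, [None], "load_labels_placeholder")
    source = tf.data.Dataset.from_tensor_slices({"data": imgs_ph, "labels": labels_ph})

    # Hypothetical preprocessing: scale pixels to [0, 1] and cast labels to int32.
    def preprocess(example):
        return {"data": tf.cast(example["data"], tf.float32) / 255.0,
                "labels": tf.cast(example["labels"], tf.int32)}

    dataset_train = source.shuffle(buffer_size=100).repeat().map(preprocess).batch(4)
    dataset_test = source.map(preprocess).batch(4)
    iterator_train = tf1.data.make_initializable_iterator(dataset_train)
    iterator_test = tf1.data.make_initializable_iterator(dataset_test)
    train_handle_op = iterator_train.string_handle()
    test_handle_op = iterator_test.string_handle()

    # Feedable iterator: the string fed to `handle` decides which pipeline get_next() reads.
    handle = tf1.placeholder(tf.string, shape=[], name="iterator_handle")
    iterator = tf1.data.Iterator.from_string_handle(
        handle,
        tf1.data.get_output_types(dataset_train),
        tf1.data.get_output_shapes(dataset_train))
    batch = iterator.get_next()

# Hypothetical data: all-black training images, all-white test images.
train_images = np.zeros((8, 32, 32, 3), np.uint8)
train_labels = np.zeros((8,), np.uint8)
test_images = np.full((6, 32, 32, 3), 255, np.uint8)
test_labels = np.ones((6,), np.uint8)

with tf1.Session(graph=graph) as sess:
    # The handles are plain strings; fetch them once and reuse them.
    handle_train, handle_test = sess.run([train_handle_op, test_handle_op])

    # Initialize each pipeline by feeding its arrays through the shared placeholders.
    sess.run(iterator_train.initializer,
             feed_dict={imgs_ph: train_images, labels_ph: train_labels})
    sess.run(iterator_test.initializer,
             feed_dict={imgs_ph: test_images, labels_ph: test_labels})

    # Switching datasets only changes which handle string is fed.
    train_batch = sess.run(batch, feed_dict={handle: handle_train})
    test_batch = sess.run(batch, feed_dict={handle: handle_test})

print(train_batch["data"].max(), test_batch["data"].min())
```

Because both pipelines read the same placeholders, each iterator keeps whatever arrays were fed when its initializer ran, and none of those arrays are written into the GraphDef or the .meta file.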

Upvotes: 1
