Michael

Reputation: 5897

How to switch model iterator between train and validate datasets?

I'm learning TensorFlow's lower-level API, where you manually specify layers using tf.layers, create datasets and iterators, and write the training and validation loops yourself. I am running into problems when I try to switch between the training and validation datasets.

Here's what I have:

self.train_it = \
    train_dataset.batch(self.batch_size).make_initializable_iterator()
self.validate_it = \
    validation_dataset.batch(self.batch_size).make_initializable_iterator()

...

input_layer = self.train_it.get_next()[0]
hidden1 = tf.layers.dense(
    input_layer,
    ... )

...

with tf.name_scope('train'):
  self.train_op = \
        tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss)

...

for epo in range(epochs):
  # Train using self.train_it iterator.
  self.sess.run(self.train_it.initializer)
  total_loss = 0
  for iteration in range(n_batches):
    summary, _, batch_loss = self.sess.run([self.merged_summary, \
        self.train_op, self.loss])
    total_loss += batch_loss
  print('   Epoch : {}/{}, Training loss = {:.4f}'. \
            format(epo+1, epochs, total_loss / n_batches))
  # Validate using self.validate_it iterator.
  self.sess.run(self.validate_it.initializer)
  # HOW DO I TELL THE MODEL TO USE self.validate_it INSTEAD OF self.train_it ???

The problem is that I already told the model to use train_it when I built the graph (input_layer = self.train_it.get_next()[0]), and now I need it to switch between train_it and validate_it every epoch. I must be missing something in the API for doing that.

Upvotes: 0

Views: 145

Answers (1)

gorjan

Reputation: 5555

I would use a reinitializable iterator, set up as follows:

train_dataset = train_dataset.batch(batch_size_train)
val_dataset = validation_dataset.batch(batch_size_val)

iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)

train_init_op = iterator.make_initializer(train_dataset)
val_init_op = iterator.make_initializer(val_dataset)

data, labels = iterator.get_next()

Then feed data and labels into the model in place of the get_next() output you use now. During training, initialize the iterator with the appropriate dataset before each phase:

for e in range(epochs):
    # Point the iterator at the training data.
    sess.run(train_init_op)
    for iteration in range(n_batches_train):
        ....
    # Point the same iterator at the validation data.
    sess.run(val_init_op)
    for iteration in range(n_batches_val):
        ....
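
For reference, here is a minimal end-to-end sketch of the same pattern. The toy data, the single-weight model, the batch sizes, and the history list are all made up for illustration, and I import tensorflow.compat.v1 only so the snippet also runs on a TensorFlow 2 install; on TF 1.x a plain "import tensorflow as tf" without the disable_eager_execution() call should behave the same. The batch counts are rounded up because batch() keeps a final partial batch.

```python
import math

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # build a graph and use sessions, as in TF 1.x

# Hypothetical toy data: learn y = 2x from 10 training and 4 validation rows.
train_x = np.arange(10, dtype=np.float32).reshape(-1, 1)
val_x = np.arange(4, dtype=np.float32).reshape(-1, 1)
train_y, val_y = 2.0 * train_x, 2.0 * val_x

batch_size_train, batch_size_val = 5, 2
train_dataset = tf.data.Dataset.from_tensor_slices(
    (train_x, train_y)).batch(batch_size_train)
val_dataset = tf.data.Dataset.from_tensor_slices(
    (val_x, val_y)).batch(batch_size_val)

# batch() keeps a final partial batch, so round the batch counts up.
n_batches_train = math.ceil(len(train_x) / batch_size_train)
n_batches_val = math.ceil(len(val_x) / batch_size_val)

# One iterator defined only by structure; either dataset can be bound to it.
iterator = tf.data.Iterator.from_structure(
    tf.data.get_output_types(train_dataset),
    tf.data.get_output_shapes(train_dataset))
train_init_op = iterator.make_initializer(train_dataset)
val_init_op = iterator.make_initializer(val_dataset)
data, labels = iterator.get_next()

# The model reads from the shared iterator, not from a specific dataset.
weight = tf.Variable(tf.zeros([1, 1]))
loss = tf.reduce_mean(tf.square(tf.matmul(data, weight) - labels))
train_op = tf.train.AdamOptimizer(learning_rate=0.1).minimize(loss)

history = []  # (mean training loss, mean validation loss) per epoch
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(3):
        # Training phase: bind the iterator to the training data.
        sess.run(train_init_op)
        train_loss = 0.0
        for _ in range(n_batches_train):
            _, batch_loss = sess.run([train_op, loss])
            train_loss += batch_loss

        # Validation phase: rebind the same iterator, evaluate loss only.
        sess.run(val_init_op)
        val_loss = 0.0
        for _ in range(n_batches_val):
            val_loss += sess.run(loss)

        history.append((train_loss / n_batches_train,
                        val_loss / n_batches_val))
        print('Epoch {}: train loss = {:.4f}, validation loss = {:.4f}'.format(
            epoch + 1, *history[-1]))
```

Note that the validation loop runs only loss and never train_op, so the weights are not updated on validation data. If you would rather not count batches, you can also loop until get_next() raises tf.errors.OutOfRangeError.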

Let me know if you find something confusing.

Upvotes: 1
