Hxyz

Reputation: 31

How to reuse a data batch from iterator.get_next()

I'm implementing an algorithm involving alternating optimization. That is, at each iteration, the algorithm fetches a data batch and uses it to optimize two losses sequentially. My current implementation with tf.data.Dataset and tf.data.Iterator looks something like this (which is in fact incorrect, as explained below):

data_batch = iterator.get_next()
train_op_1 = get_train_op(data_batch)   # optimizes the first loss
train_op_2 = get_train_op(data_batch)   # optimizes the second loss

for _ in range(num_steps):
    sess.run(train_op_1)   # fetches a batch
    sess.run(train_op_2)   # fetches another batch

Note that the above is incorrect because each call to sess.run advances the iterator and fetches the next data batch, so train_op_1 and train_op_2 end up using different data batches.

I cannot do something like sess.run([train_op_1, train_op_2]) either, because the two optimization steps need to be sequential (i.e., the 2nd optimization step depends on the latest variable values produced by the 1st optimization step).

I'm wondering whether there is any way to somehow "freeze" the iterator, so that it won't advance during a sess.run call?

Upvotes: 3

Views: 651

Answers (2)

MPękalski

Reputation: 7103

I was doing something similar, so here is part of my code, stripped of some unnecessary stuff. It does a bit more than you need, since it has both a train and a validation iterator, but you should get the idea of using the is_keep_previous flag: when fed as True it forces reuse of the previous value from the iterator, and when False it fetches a new value.

iterator_t = ds_t.make_initializable_iterator()
iterator_v = ds_v.make_initializable_iterator()

iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
iterator = tf.data.Iterator.from_string_handle(iterator_handle,
                                               iterator_t.output_types,
                                               iterator_t.output_shapes)

def new_data():
  # Pull the next element from whichever iterator the handle points to.
  # Sometimes items need casting.
  next_elem = iterator.get_next(name="next_element")
  x, y = tf.cast(next_elem[0], tf.float32), next_elem[1]
  return x, y

def old_data():
  # Just forward the existing batch. (In the full, unstripped code, inputs and
  # target need to already hold the previously fetched batch at this point,
  # e.g. cached in non-trainable variables.)
  return inputs, target

is_keep_previous = tf.placeholder_with_default(tf.constant(False), shape=[], name="keep_previous_flag")

inputs, target = tf.cond(is_keep_previous, old_data, new_data)

with tf.Session() as sess:
  sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
  handle_t = sess.run(iterator_t.string_handle())
  handle_v = sess.run(iterator_v.string_handle())
  # Run data iterator initialisation.
  sess.run(iterator_t.initializer)
  sess.run(iterator_v.initializer)
  while True:
    try:
      # Fresh batch from the training iterator.
      inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_keep_previous: False})
      print(inputs_, target_)
      # Same batch again; the training iterator is not advanced.
      inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_keep_previous: True})
      print(inputs_, target_)
      # Fresh batch from the validation iterator (the flag defaults to False).
      inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_v})
      print(inputs_, target_)
    except tf.errors.OutOfRangeError:
      # Now we know one of the iterators has run out of elements.
      break
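
Applied to the question's alternating optimization, the idea would be to build both training ops on the inputs and target coming out of the tf.cond, then flip the flag between the two run calls. A rough sketch, borrowing get_train_op and num_steps from the question (and assuming get_train_op accepts the unpacked batch, and that old_data really does forward a cached copy of the previous batch as noted above):

train_op_1 = get_train_op(inputs, target)
train_op_2 = get_train_op(inputs, target)

for _ in range(num_steps):
    # First step: pull a fresh batch from the training iterator.
    sess.run(train_op_1, feed_dict={iterator_handle: handle_t, is_keep_previous: False})
    # Second step: reuse the same batch; the iterator does not advance.
    sess.run(train_op_2, feed_dict={iterator_handle: handle_t, is_keep_previous: True})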

Upvotes: 1

Alexandre Passos
Alexandre Passos

Reputation: 5206

Use control dependencies when building the graph for train_op_2 so it can see the updated values of the variables.
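
For the setup in the question, that could look roughly like the following minimal sketch (it assumes, as in the question, that get_train_op builds the optimizer step for one loss on the given batch):

data_batch = iterator.get_next()

train_op_1 = get_train_op(data_batch)
with tf.control_dependencies([train_op_1]):
    # Ops created here run after train_op_1, so this step reads the
    # variable values already updated by the first step.
    train_op_2 = get_train_op(data_batch)

for _ in range(num_steps):
    # One run call evaluates get_next() only once, so both steps share the
    # same batch, and the control dependency enforces their ordering.
    sess.run(train_op_2)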

Or use eager execution.
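
With eager execution the batch is an ordinary tensor that can be reused freely across any number of update steps. A minimal sketch (the toy dataset, variables, and losses below are made up purely for illustration):

import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x; in TF 2.x eager is the default

# Toy data and variables, for illustration only.
dataset = tf.data.Dataset.from_tensor_slices(tf.random_normal([100, 4])).batch(10)
w1 = tf.Variable(tf.zeros([4]))
w2 = tf.Variable(tf.zeros([4]))
opt = tf.train.GradientDescentOptimizer(0.1)

for data_batch in dataset:
    # First optimization step.
    with tf.GradientTape() as tape:
        loss_1 = tf.reduce_mean(tf.square(tf.reduce_sum(data_batch * w1, axis=1) - 1.0))
    opt.apply_gradients(zip(tape.gradient(loss_1, [w1]), [w1]))

    # Second step reuses exactly the same data_batch, after w1 has been updated.
    with tf.GradientTape() as tape:
        loss_2 = tf.reduce_mean(tf.square(tf.reduce_sum(data_batch * (w1 + w2), axis=1) - 1.0))
    opt.apply_gradients(zip(tape.gradient(loss_2, [w2]), [w2]))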

Upvotes: 0
