Reputation: 31
I'm implementing an algorithm involving alternating optimization. That is, at each iteration the algorithm fetches a data batch and uses it to optimize two losses sequentially. My current implementation with tf.data.Dataset and tf.data.Iterator is something like this (which is indeed incorrect, as detailed below):
data_batch = iterator.get_next()
train_op_1 = get_train_op(data_batch)
train_op_2 = get_train_op(data_batch)

for _ in range(num_steps):
    sess.run(train_op_1)
    sess.run(train_op_2)
Note that the above is incorrect because each call to sess.run advances the iterator and fetches the next data batch, so train_op_1 and train_op_2 actually end up using different data batches.
I cannot do something like sess.run([train_op_1, train_op_2]) either, because the two optimization steps need to be sequential (i.e., the second optimization step depends on the variable values just updated by the first).
Is there any way to somehow "freeze" the iterator, so that it won't advance during a sess.run call?
Upvotes: 3
Views: 651
Reputation: 7103
I was doing something similar, so here is part of my code, stripped of some unnecessary stuff. It does a bit more, as it has both a train and a validation iterator, but you should get the idea of using the is_keep_previous flag. When fed as True it forces reuse of the previous value of the iterator; when False it fetches a new value.
import tensorflow as tf

# ds_t and ds_v are the training / validation tf.data.Dataset objects (defined elsewhere)
iterator_t = ds_t.make_initializable_iterator()
iterator_v = ds_v.make_initializable_iterator()

iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")

iterator = tf.data.Iterator.from_string_handle(iterator_handle,
                                               iterator_t.output_types,
                                               iterator_t.output_shapes)

def new_data():
    # fetch a fresh batch; sometimes items need casting
    next_elem = iterator.get_next(name="next_element")
    x, y = tf.cast(next_elem[0], tf.float32), next_elem[1]
    return x, y

def old_data():
    # just forward the existing batch (the caching that provides inputs/target here
    # was part of the code stripped out of this excerpt)
    return inputs, target

is_keep_previous = tf.placeholder_with_default(tf.constant(False), shape=[], name="keep_previous_flag")

inputs, target = tf.cond(is_keep_previous, old_data, new_data)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    handle_t = sess.run(iterator_t.string_handle())
    handle_v = sess.run(iterator_v.string_handle())
    # run the data iterator initialisation
    sess.run(iterator_t.initializer)
    sess.run(iterator_v.initializer)

    while True:
        try:
            inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_keep_previous: False})
            print(inputs_, target_)
            inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_keep_previous: True})
            print(inputs_, target_)
            inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_v})
            print(inputs_, target_)
        except tf.errors.OutOfRangeError:
            # we have run out of elements in the validation iterator
            break
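To make the old_data/new_data idea fully self-contained, here is a minimal sketch that explicitly caches the last fetched batch in non-trainable variables. The toy dataset and the names cached_x, cached_y, keep_previous are illustrative (not part of the code above), and it assumes TF 1.x graph mode:

import tensorflow as tf

# Toy dataset: 100 examples of 4 features with integer labels, batched by 10.
ds = tf.data.Dataset.from_tensor_slices(
    (tf.random_uniform([100, 4]),
     tf.random_uniform([100], maxval=10, dtype=tf.int32))).batch(10)
it = ds.make_initializable_iterator()

# Non-trainable variables that cache the most recently fetched batch.
cached_x = tf.Variable(tf.zeros([10, 4]), trainable=False)
cached_y = tf.Variable(tf.zeros([10], dtype=tf.int32), trainable=False)

def new_data():
    # draw a fresh batch and store it in the cache variables
    x, y = it.get_next()
    with tf.control_dependencies([tf.assign(cached_x, x), tf.assign(cached_y, y)]):
        return tf.identity(x), tf.identity(y)

def old_data():
    # replay the cached batch; the iterator is not touched
    return tf.identity(cached_x), tf.identity(cached_y)

keep_previous = tf.placeholder_with_default(False, shape=[])
inputs, target = tf.cond(keep_previous, old_data, new_data)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), it.initializer])
    a = sess.run(inputs)                                    # advances the iterator
    b = sess.run(inputs, feed_dict={keep_previous: True})   # same batch as a
    c = sess.run(inputs)                                    # next batch

For the alternating updates in the question, you would then run the first train op with keep_previous fed as False and the second with True, so both steps see the same batch.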
Upvotes: 1
Reputation: 5206
Use control dependencies when building the graph for train_op_2 so it can see the updated values of the variables.
Or use eager execution.
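A minimal sketch of the control-dependency approach, reusing the names from the question; get_train_op_1 and get_train_op_2 stand in for whatever builds each optimizer step and are illustrative, not a fixed API:

data_batch = iterator.get_next()          # fetched once per session.run
train_op_1 = get_train_op_1(data_batch)   # first optimization step

with tf.control_dependencies([train_op_1]):
    # ops created here (the second loss and its optimizer step) run only
    # after train_op_1 has finished, so they see the updated variables
    train_op_2 = get_train_op_2(data_batch)

for _ in range(num_steps):
    sess.run(train_op_2)   # one run: one batch fetch, step 1 then step 2

Note that the second loss has to be constructed inside the control_dependencies block; a dependency added after the ops already exist would not reorder them.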
Upvotes: 0