nbro

Reputation: 15837

What is the intuition behind the Iterator.get_next method?

The name of the method get_next() is a little misleading. The documentation says:

Returns a nested structure of tf.Tensors representing the next element.

In graph mode, you should typically call this method once and use its result as the input to another computation. A typical loop will then call tf.Session.run on the result of that computation. The loop will terminate when the Iterator.get_next() operation raises tf.errors.OutOfRangeError. The following skeleton shows how to use this method when building a training loop:

dataset = ...  # A `tf.data.Dataset` object.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Build a TensorFlow graph that does something with each element.
loss = model_function(next_element)
optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
train_op = optimizer.minimize(loss)

with tf.compat.v1.Session() as sess:
  try:
    while True:
      sess.run(train_op)
  except tf.errors.OutOfRangeError:
    pass
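A concrete, runnable version of the skeleton above may help. This is a minimal sketch assuming TensorFlow 2.x with the tf.compat.v1 API. The `...` placeholders are filled with tf.data.Dataset.range and a doubling op that stands in for the hypothetical model_function; there is no optimizer. Two things differ from the quoted skeleton: in TF 2.x the equivalent of dataset.make_initializable_iterator() is the free function tf.compat.v1.data.make_initializable_iterator, and an initializable iterator has to be initialized with sess.run(iterator.initializer) before the first element is requested.

```python
import tensorflow as tf


def run_skeleton(n=5):
    """Consume a tf.data pipeline in graph mode; run_skeleton is a hypothetical helper."""
    collected = []
    # Build everything in an explicit graph so this also works in TF 2.x,
    # where eager execution is the default.
    with tf.Graph().as_default():
        dataset = tf.data.Dataset.range(n)
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        # get_next() is called exactly once, while the graph is being built.
        next_element = iterator.get_next()
        # Stand-in for model_function: any computation that consumes the element.
        doubled = next_element * 2
        with tf.compat.v1.Session() as sess:
            # An initializable iterator must be initialized before the first run.
            sess.run(iterator.initializer)
            try:
                while True:
                    # Each run pulls one fresh element through the same ops.
                    collected.append(int(sess.run(doubled)))
            except tf.errors.OutOfRangeError:
                # Raised by the get_next op once the dataset is exhausted.
                pass
    return collected


print(run_skeleton())  # [0, 2, 4, 6, 8]
```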

Python also has a built-in function, next, which needs to be called every time we want the next element of an iterator. However, according to the documentation of get_next() quoted above, get_next() should be called only once, and its result should be evaluated by calling the session's run method. This is a little unintuitive, because I am used to Python's built-in next. In this script, get_next() is also called only once, and the result of that call is evaluated at every step of the computation.
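For comparison, here is the plain-Python pattern the question has in mind, as a small sketch (consume_with_next is a hypothetical helper name). next() is called once per element, and exhaustion is signalled by StopIteration, much as the TensorFlow loop ends on tf.errors.OutOfRangeError.

```python
def consume_with_next(iterable):
    """Pull every element from an iterator by calling next() once per element."""
    it = iter(iterable)
    seen = []
    while True:
        try:
            # Each call to next() is a new request for the next element.
            seen.append(next(it))
        except StopIteration:
            # The iterator is exhausted.
            break
    return seen


print(consume_with_next([1, 2, 3, 4]))  # [1, 2, 3, 4]
```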

What is the intuition behind get_next(), and how is it different from next()? I think that the next element of the dataset (or feedable iterator) in the second example I linked above is retrieved every time the result of the first call to get_next() is evaluated with the session's run method, but this is unintuitive. I don't understand why we do not need to call get_next() at every step of the computation (to get the next element of the feedable iterator), even after reading this note in the documentation:

NOTE: It is legitimate to call Iterator.get_next() multiple times, e.g. when you are distributing different elements to multiple devices in a single step. However, a common pitfall arises when users call Iterator.get_next() in each iteration of their training loop. Iterator.get_next() adds ops to the graph, and executing each op allocates resources (including threads); as a consequence, invoking it in every iteration of a training loop causes slowdown and eventual resource exhaustion. To guard against this outcome, we log a warning when the number of uses crosses a fixed threshold of suspiciousness.

In general, it is not clear how the Iterator works.
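The pitfall described in the note can be made visible by counting the operations in the graph. This is a sketch assuming TensorFlow 2.x with tf.compat.v1; both helper names are hypothetical. Calling get_next() repeatedly keeps adding operations, whereas calling it once and running its result repeatedly leaves the graph unchanged.

```python
import tensorflow as tf


def count_ops_after_calls(num_calls):
    """Return the graph's op count after each of num_calls get_next() calls."""
    counts = []
    with tf.Graph().as_default() as graph:
        dataset = tf.data.Dataset.range(10)
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        for _ in range(num_calls):
            iterator.get_next()  # every call builds new ops in the graph
            counts.append(len(graph.get_operations()))
    return counts


def count_ops_while_running(num_steps):
    """Call get_next() once, then evaluate its result num_steps times."""
    with tf.Graph().as_default() as graph:
        dataset = tf.data.Dataset.range(10)
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()
        before = len(graph.get_operations())
        with tf.compat.v1.Session() as sess:
            for _ in range(num_steps):
                sess.run(next_element)  # reuses the existing ops
        after = len(graph.get_operations())
    return before, after


print(count_ops_after_calls(3))       # strictly increasing counts
before, after = count_ops_while_running(5)
print(before == after)                # True
```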

Upvotes: 3

Views: 2330

Answers (1)

javidcf

Reputation: 59691

The idea is that get_next adds some operations to the graph such that, every time you evaluate them, you get the next element in the dataset. On each iteration, you just need to run the operations that get_next created; you do not need to create them over and over again.
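A rough pure-Python analogy of this idea, not how TensorFlow is implemented internally: build_get_next is a hypothetical stand-in for get_next(), and calling the function it returns stands in for sess.run(next_element). The "op" is built once and then evaluated many times, and each evaluation advances shared state.

```python
def build_get_next(iterable):
    """Hypothetical analogue of get_next(): called once, returns a reusable 'op'."""
    it = iter(iterable)

    def next_op():
        # Evaluating the op advances the shared iterator state,
        # like running the tensor returned by get_next().
        return next(it)

    return next_op


# "Graph construction": build the op exactly once.
next_element = build_get_next([1, 2, 3, 4])

# "Training loop": evaluate the same op repeatedly.
values = []
try:
    while True:
        values.append(next_element())
except StopIteration:
    pass
print(values)  # [1, 2, 3, 4]
```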

Maybe a good way to get an intuition is to try to write an iterator yourself. Consider something like the following:

import tensorflow as tf
tf.compat.v1.disable_v2_behavior()

# Make an iterator, returns next element and initializer
def iterator_next(data):
    data = tf.convert_to_tensor(data)
    i = tf.Variable(0)
    # Check we are not out of bounds
    with tf.control_dependencies([tf.assert_less(i, tf.shape(data)[0])]):
        # Get next value
        next_val_1 = data[i]
    # Update index after the value is read
    with tf.control_dependencies([next_val_1]):
        i_updated = tf.compat.v1.assign_add(i, 1)
        with tf.control_dependencies([i_updated]):
            next_val_2 = tf.identity(next_val_1)
    return next_val_2, i.initializer

# Test
with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess:
    # Example data
    data = tf.constant([1, 2, 3, 4])
    # Make operations that give you the next element
    next_val, iter_init = iterator_next(data)
    # Initialize iterator
    sess.run(iter_init)
    # Iterate until exception is raised
    while True:
        try:
            print(sess.run(next_val))
        except tf.errors.InvalidArgumentError:
            # tf.assert_less fails once i reaches the end of the data
            break

Output:

1
2
3
4

Here, iterator_next gives you something comparable to what get_next gives you for an iterator, plus an initializer operation. Every time you run next_val, you get a new element from data. You don't call the function once per element (which is how next works in Python); you call it once and then evaluate the result multiple times.
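"Evaluate" here means one call to sess.run: within a single run, each op executes at most once. The following sketch, assuming TensorFlow 2.x with tf.compat.v1 (fetch_patterns is a hypothetical helper), shows that fetching the same get_next result twice in one run yields the same element, while two separate get_next() calls in one run consume two elements. The latter is the legitimate multi-call use the documentation note mentions; the order between the two ops within a step is not guaranteed, so the pair is sorted.

```python
import tensorflow as tf


def fetch_patterns():
    """Contrast fetching one get_next tensor twice with two get_next ops."""
    with tf.Graph().as_default():
        dataset = tf.data.Dataset.range(10)
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        a = iterator.get_next()
        b = iterator.get_next()  # a second, separate op on the same iterator
        with tf.compat.v1.Session() as sess:
            # The same tensor fetched twice in one run is computed only once.
            same = sess.run([a, a])
            # Two distinct get_next ops in one run each consume an element.
            pair = sess.run([a, b])
    return [int(x) for x in same], sorted(int(x) for x in pair)


print(fetch_patterns())  # ([0, 0], [1, 2])
```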

EDIT: The function iterator_next above could also be simplified to the following:

def iterator_next(data):
    data = tf.convert_to_tensor(data)
    # Start from -1
    i = tf.Variable(-1)
    # First increment i
    i_updated = tf.compat.v1.assign_add(i, 1)
    with tf.control_dependencies([i_updated]):
        # Check i is not out of bounds
        with tf.control_dependencies([tf.assert_less(i, tf.shape(data)[0])]):
            # Get next value
            next_val = data[i]
    return next_val, i.initializer
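The reason i starts at -1 is that the increment now happens before the read, so the first read sees index 0. A pure-Python mirror of that ordering may make it clearer; CounterIterator is a hypothetical class, and IndexError stands in for the InvalidArgumentError that tf.assert_less raises.

```python
class CounterIterator:
    """Hypothetical pure-Python mirror of the simplified iterator_next graph."""

    def __init__(self, data):
        self.data = list(data)
        self.initialize()

    def initialize(self):
        # Equivalent of running i.initializer: the index starts at -1.
        self.i = -1

    def run_next_val(self):
        # Equivalent of one sess.run(next_val):
        self.i += 1                       # 1. assign_add(i, 1) runs first
        if not self.i < len(self.data):   # 2. bounds check on the updated i
            raise IndexError("iterator exhausted")
        return self.data[self.i]          # 3. read at the updated index


it = CounterIterator([1, 2, 3, 4])
out = []
while True:
    try:
        out.append(it.run_next_val())
    except IndexError:
        break
print(out)  # [1, 2, 3, 4]
```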

Or even simpler:

def iterator_next(data):
    data = tf.convert_to_tensor(data)
    i = tf.Variable(-1)
    i_updated = tf.compat.v1.assign_add(i, 1)
    # Using i_updated directly as a value is equivalent to using i with
    # a control dependency to i_updated
    with tf.control_dependencies([tf.assert_less(i_updated, tf.shape(data)[0])]):
        next_val = data[i_updated]
    return next_val, i.initializer
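The initializer returned alongside next_val plays the same role as iterator.initializer for an initializable tf.data iterator: running it again rewinds the iteration, while the ops built by get_next() are reused. A sketch assuming TensorFlow 2.x with tf.compat.v1 (two_passes is a hypothetical helper):

```python
import tensorflow as tf


def two_passes(n=3):
    """Iterate a dataset twice, re-running the initializer between passes."""
    passes = []
    with tf.Graph().as_default():
        dataset = tf.data.Dataset.range(n)
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        next_element = iterator.get_next()  # built once, reused for both passes
        with tf.compat.v1.Session() as sess:
            for _ in range(2):
                sess.run(iterator.initializer)  # rewind to the first element
                values = []
                try:
                    while True:
                        values.append(int(sess.run(next_element)))
                except tf.errors.OutOfRangeError:
                    pass
                passes.append(values)
    return passes


print(two_passes())  # [[0, 1, 2], [0, 1, 2]]
```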

Upvotes: 2
