Reputation: 15837
The name of the method get_next() is a little bit misleading. The documentation says

Returns a nested structure of tf.Tensors representing the next element.

In graph mode, you should typically call this method once and use its result as the input to another computation. A typical loop will then call tf.Session.run on the result of that computation. The loop will terminate when the Iterator.get_next() operation raises tf.errors.OutOfRangeError. The following skeleton shows how to use this method when building a training loop:
dataset = ... # A `tf.data.Dataset` object.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Build a TensorFlow graph that does something with each element.
loss = model_function(next_element)
optimizer = ... # A `tf.compat.v1.train.Optimizer` object.
train_op = optimizer.minimize(loss)
with tf.compat.v1.Session() as sess:
    try:
        while True:
            sess.run(train_op)
    except tf.errors.OutOfRangeError:
        pass
Python also has a function called next, which needs to be called every time we need the next element of the iterator. However, according to the documentation of get_next() quoted above, get_next() should be called only once and its result should be evaluated by calling the method run of the session. This is a little bit unintuitive, because I am used to Python's built-in function next. In this script, get_next() is also called only once, and the result of the call is evaluated at every step of the computation.
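For comparison, this is the pattern I am used to in plain Python, where next has to be called again for every element (a minimal sketch, unrelated to TensorFlow):
it = iter([1, 2, 3, 4])
while True:
    try:
        # A new call to next() is needed for every element
        element = next(it)
        print(element)
    except StopIteration:
        break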
What is the intuition behind get_next() and how is it different from next()? I think that the next element of the dataset (or feedable iterator), in the second example I linked above, is retrieved every time the result of the first call to get_next() is evaluated by calling the method run, but this is a little unintuitive. I don't get why we do not need to call get_next at every step of the computation (to get the next element of the feedable iterator), even after reading the note in the documentation:
NOTE: It is legitimate to call Iterator.get_next() multiple times, e.g. when you are distributing different elements to multiple devices in a single step. However, a common pitfall arises when users call Iterator.get_next() in each iteration of their training loop. Iterator.get_next() adds ops to the graph, and executing each op allocates resources (including threads); as a consequence, invoking it in every iteration of a training loop causes slowdown and eventual resource exhaustion. To guard against this outcome, we log a warning when the number of uses crosses a fixed threshold of suspiciousness.
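If I read the note correctly, these are the two patterns it contrasts (a rough sketch that reuses the iterator and sess from the skeleton above):
# Pitfall: get_next() inside the loop adds new ops to the graph on every iteration
while True:
    sess.run(iterator.get_next())

# Pattern from the documentation: get_next() is called once,
# and only the resulting op is evaluated inside the loop
next_element = iterator.get_next()
while True:
    sess.run(next_element)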
In general, it is not clear how the Iterator works.
Upvotes: 3
Views: 2330
Reputation: 59691
The idea is that get_next adds some operations to the graph such that, every time you evaluate them, you get the next element in the dataset. On each iteration, you just need to run the operations that get_next created; you do not need to create them over and over again.
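For example, with an actual tf.data.Dataset the pattern looks roughly like this (a minimal sketch using the tf.compat.v1 API; the small range dataset is just for illustration):
import tensorflow as tf
tf.compat.v1.disable_v2_behavior()

dataset = tf.data.Dataset.range(4)
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
# Ops are added to the graph once, here
next_element = iterator.get_next()

with tf.compat.v1.Session() as sess:
    sess.run(iterator.initializer)
    while True:
        try:
            # Running the same op again fetches the next element
            print(sess.run(next_element))
        except tf.errors.OutOfRangeError:
            break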
Maybe a good way to get an intuition is to try to write an iterator yourself. Consider something like the following:
import tensorflow as tf
tf.compat.v1.disable_v2_behavior()
# Make an iterator, returns next element and initializer
def iterator_next(data):
    data = tf.convert_to_tensor(data)
    i = tf.Variable(0)
    # Check we are not out of bounds
    with tf.control_dependencies([tf.assert_less(i, tf.shape(data)[0])]):
        # Get next value
        next_val_1 = data[i]
    # Update index after the value is read
    with tf.control_dependencies([next_val_1]):
        i_updated = tf.compat.v1.assign_add(i, 1)
    with tf.control_dependencies([i_updated]):
        next_val_2 = tf.identity(next_val_1)
    return next_val_2, i.initializer
# Test
with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess:
    # Example data
    data = tf.constant([1, 2, 3, 4])
    # Make operations that give you the next element
    next_val, iter_init = iterator_next(data)
    # Initialize iterator
    sess.run(iter_init)
    # Iterate until exception is raised
    while True:
        try:
            print(sess.run(next_val))
        # assert throws InvalidArgumentError
        except tf.errors.InvalidArgumentError:
            break
Output:
1
2
3
4
Here, iterator_next gives you something comparable to what get_next in an iterator would give you, plus an initializer operation. Every time you run next_val you get a new element from data. You don't need to call the function every time (which is how next works in Python); you call it once and then evaluate the result multiple times.
EDIT: The function iterator_next above could also be simplified to the following:
def iterator_next(data):
    data = tf.convert_to_tensor(data)
    # Start from -1
    i = tf.Variable(-1)
    # First increment i
    i_updated = tf.compat.v1.assign_add(i, 1)
    with tf.control_dependencies([i_updated]):
        # Check i is not out of bounds
        with tf.control_dependencies([tf.assert_less(i, tf.shape(data)[0])]):
            # Get next value
            next_val = data[i]
    return next_val, i.initializer
Or even simpler:
def iterator_next(data):
    data = tf.convert_to_tensor(data)
    i = tf.Variable(-1)
    i_updated = tf.compat.v1.assign_add(i, 1)
    # Using i_updated directly as a value is equivalent to using i with
    # a control dependency to i_updated
    with tf.control_dependencies([tf.assert_less(i_updated, tf.shape(data)[0])]):
        next_val = data[i_updated]
    return next_val, i.initializer
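Either simplified version drops into the same test loop above unchanged: once the index moves past the last element, the assert fails and the try/except on InvalidArgumentError still terminates the loop.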
Upvotes: 2