D Myers

Reputation: 73

When to use an iterator in Tensorflow Estimator

The TensorFlow guides describe the input function for the Iris data example in two separate places. One input function returns just the dataset itself, while the other returns the next element of an iterator created from the dataset.

From the premade Estimator guide: https://www.tensorflow.org/guide/premade_estimators

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    return dataset.shuffle(1000).repeat().batch(batch_size)
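As a rough illustration of what this first `train_input_fn` hands back, here is a pure-Python analogue of the pipeline. It is not TensorFlow code: `ToyDataset`, `toy_from_tensor_slices`, `toy_shuffle` and `toy_train_input_fn` are hypothetical helpers, and the buffered shuffle only approximates `tf.data`'s behavior. The point it tries to show is that a dataset is a re-iterable description of the pipeline, and each pass over it yields batches shaped like `(dict_of_feature_columns, labels)`.

```python
import itertools
import random
from typing import Callable, Dict, Iterator, List, Tuple

# One example: a dict of feature values plus an integer label.
Row = Tuple[Dict[str, float], int]


class ToyDataset:
    """Hypothetical stand-in for tf.data.Dataset: iter() starts a fresh pipeline."""

    def __init__(self, make_iterator: Callable[[], Iterator]):
        self._make_iterator = make_iterator

    def __iter__(self) -> Iterator:
        return self._make_iterator()


def toy_from_tensor_slices(features: Dict[str, List[float]], labels: List[int]) -> List[Row]:
    """Slice column-oriented inputs into one (features, label) pair per example."""
    return [({name: col[i] for name, col in features.items()}, labels[i])
            for i in range(len(labels))]


def toy_shuffle(rows: List[Row], buffer_size: int, rng: random.Random) -> Iterator[Row]:
    """Buffered shuffle: keep up to buffer_size rows, emit a random one per new row."""
    buffer: List[Row] = []
    for row in rows:
        if len(buffer) < buffer_size:
            buffer.append(row)
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = row
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


def toy_train_input_fn(features, labels, batch_size, seed=0) -> ToyDataset:
    """Rough analogue of dataset.shuffle(1000).repeat().batch(batch_size)."""
    rows = toy_from_tensor_slices(features, labels)

    def make_iterator():
        rng = random.Random(seed)
        # repeat(): endless passes over the data, each one reshuffled.
        examples = itertools.chain.from_iterable(
            toy_shuffle(rows, 1000, rng) for _ in itertools.count())
        while True:
            # batch(): group consecutive examples; a batch may span two passes.
            chunk = list(itertools.islice(examples, batch_size))
            yield ({name: [f[name] for f, _ in chunk] for name in features},
                   [label for _, label in chunk])

    return ToyDataset(make_iterator)
```

Reshuffling on every pass is meant to approximate `tf.data`'s default `reshuffle_each_iteration=True`; because `batch` comes after `repeat`, a batch can contain examples from two different passes.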

From the custom estimator guide: https://www.tensorflow.org/guide/custom_estimators

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Return the read end of the pipeline.
    return dataset.make_one_shot_iterator().get_next()

I'm confused about which one is correct. If both are valid for different cases, when is it correct to return the output of an iterator instead of the dataset itself?

Upvotes: 7

Views: 667

Answers (1)

xdurch0

Reputation: 10474

If your input function returns a tf.data.Dataset, an iterator is created under the hood and its get_next() method is used to supply inputs to the model. This is somewhat hidden in the source code; see the parse_input_fn_result function in the Estimator source.
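Purely as an illustration (this is neither TensorFlow code nor the actual `parse_input_fn_result` source), the dispatch described above can be modeled in plain Python. `ToyDataset`, `split_features_labels`, `toy_train`, `dataset_input_fn` and `iterator_input_fn` are hypothetical names. The "read end" returned by the second style is modeled as a zero-argument callable that produces the next batch each time it is called, standing in for evaluating the `get_next()` tensors once per training step.

```python
from typing import Any, Callable, Iterator, Optional, Tuple


class ToyDataset:
    """Hypothetical stand-in for tf.data.Dataset: iter() starts a fresh pass."""

    def __init__(self, make_iterator: Callable[[], Iterator[Any]]):
        self._make_iterator = make_iterator

    def __iter__(self) -> Iterator[Any]:
        return self._make_iterator()


def split_features_labels(element: Any) -> Tuple[Any, Optional[Any]]:
    """Accept either a (features, labels) pair or features alone (assumed convention)."""
    if isinstance(element, (list, tuple)):
        if len(element) != 2:
            raise ValueError("input_fn should return (features, labels)")
        return element[0], element[1]
    return element, None


def toy_train(input_fn: Callable[[], Any], steps: int) -> list:
    """Toy training loop: call input_fn once, then read one batch per step."""
    result = input_fn()
    if isinstance(result, ToyDataset):
        # A dataset was returned, so the "framework" creates the iterator itself.
        read_next = iter(result).__next__
    else:
        # input_fn already returned the read end of the pipeline.
        read_next = result
    return [split_features_labels(read_next()) for _ in range(steps)]


# Two fixed batches of (features, labels), standing in for a batched dataset.
BATCHES = [({"x": [1, 2]}, [0, 1]), ({"x": [3, 4]}, [1, 0])]


def dataset_input_fn():
    # Style 1: return the dataset itself.
    return ToyDataset(lambda: iter(BATCHES))


def iterator_input_fn():
    # Style 2: return the read end, like make_one_shot_iterator().get_next().
    return iter(dataset_input_fn()).__next__
```

In this toy model, `toy_train(dataset_input_fn, 2)` and `toy_train(iterator_input_fn, 2)` return the same two batches; the only difference is whether the input function or the training loop creates the iterator.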

I believe this was only implemented in a more recent release, so older tutorials still explicitly return get_next() from their input function, since that was the only option at the time. There should be no functional difference between the two, but returning the dataset instead of the iterator saves a little code.

Upvotes: 5
