Reputation: 73
In the TensorFlow guides, there are two separate places where the input function for the Iris data example is described. One input function returns the dataset itself, while the other returns the output of an iterator over the dataset.
From the premade Estimator guide: https://www.tensorflow.org/guide/premade_estimators
import tensorflow as tf

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    # Shuffle, repeat, and batch the examples.
    return dataset.shuffle(1000).repeat().batch(batch_size)
From the custom Estimator guide: https://www.tensorflow.org/guide/custom_estimators
import tensorflow as tf

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
    # Return the read end of the pipeline.
    # (make_one_shot_iterator() is a TF 1.x Dataset method.)
    return dataset.make_one_shot_iterator().get_next()
I'm confused about which one is correct. If both are valid but meant for different cases, when should the input function return the iterator's output instead of the dataset itself?
Upvotes: 7
Views: 667
Reputation: 10474
If your input function returns a tf.data.Dataset, an iterator is created under the hood and its get_next() function is used to supply inputs to the model. This is somewhat hidden in the source code; see parse_input_fn_result here.
I believe this was only implemented in a more recent update, so older tutorials still explicitly return get_next() in their input function, since that was the only option at the time. There should be no functional difference between the two, but returning the dataset instead of the iterator saves a little code.
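A minimal pure-Python sketch of the idea, assuming a toy model of how the Estimator consumes input_fn results. `ToyDataset`, `normalize_input_fn_result` and `run_steps` are hypothetical names, not TensorFlow APIs, and a plain Python iterator stands in for the get_next() tensors that TensorFlow re-evaluates at every training step. It shows that building the iterator on the caller's behalf gives the training loop the same sequence of batches as returning the read end explicitly.

```python
from typing import Any, Callable, Iterable, Iterator, List, Tuple


class ToyDataset:
    """Hypothetical stand-in for tf.data.Dataset: a re-iterable list of elements."""

    def __init__(self, elements: Iterable[Any]) -> None:
        self._elements = list(elements)

    def make_one_shot_iterator(self) -> Iterator[Any]:
        # Toy analogue of the TF 1.x method of the same name.
        return iter(self._elements)


def normalize_input_fn_result(result: Any) -> Iterator[Any]:
    """Return the 'read end' of the pipeline, whichever style input_fn used."""
    if isinstance(result, ToyDataset):
        # A dataset was returned: create the iterator on the caller's behalf.
        return result.make_one_shot_iterator()
    # Already the read end (the older, explicit style): use it as-is.
    return result


def run_steps(input_fn: Callable[[], Any], steps: int) -> List[Tuple[Any, Any]]:
    """Call input_fn once, then pull one (features, labels) element per step."""
    read_end = normalize_input_fn_result(input_fn())
    batches = []
    for _ in range(steps):
        features, labels = next(read_end)
        batches.append((features, labels))
    return batches


# Made-up Iris-like elements: (features dict, label).
data = [({"SepalLength": 5.1}, 0), ({"SepalLength": 7.0}, 1), ({"SepalLength": 6.3}, 2)]


def dataset_style_input_fn():
    # Newer style: return the dataset itself.
    return ToyDataset(data)


def iterator_style_input_fn():
    # Older style: return the read end of an iterator.
    return ToyDataset(data).make_one_shot_iterator()


print(run_steps(dataset_style_input_fn, 3) == run_steps(iterator_style_input_fn, 3))  # True
```

In this toy, the only difference between the two styles is who creates the iterator, which is why the framework can accept either.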
Upvotes: 5