Esi

Reputation: 33

Issue splitting a dataset with the TensorFlow Dataset API

I am reading a CSV file with tf.contrib.data.make_csv_dataset to build a dataset, and then I use take() to build another dataset with just one element, but it still seems to return all the elements.

What is wrong here? My code is below:

import tensorflow as tf
import os
tf.enable_eager_execution()

# Constants

column_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']
class_names = ['Iris setosa', 'Iris versicolor', 'Iris virginica']
batch_size = 1
feature_names = column_names[:-1]
label_name = column_names[-1]

# Reorient the data structure: stack the feature dict into one tensor
def pack_features_vector(features, labels):
  """Pack the features into a single array."""
  features = tf.stack(list(features.values()), axis=1)
  return features, labels

# Download the file
train_dataset_url = "http://download.tensorflow.org/data/iris_training.csv"
train_dataset_fp = tf.keras.utils.get_file(fname=os.path.basename(train_dataset_url),
                                       origin=train_dataset_url)

# form the dataset
train_dataset = tf.contrib.data.make_csv_dataset(
    train_dataset_fp,
    batch_size,
    column_names=column_names,
    label_name=label_name,
    num_epochs=1)

# perform the mapping
train_dataset = train_dataset.map(pack_features_vector)

# Construct a dataset with one element
train_dataset= train_dataset.take(1)

# inspect elements
for step in range(10):
    features, labels = next(iter(train_dataset))
    print(list(features))

Upvotes: 2

Views: 580

Answers (1)

Amir

Reputation: 16607

Based on this answer, you can split a Dataset with Dataset.take() and Dataset.skip():

# DATASET_SIZE: total number of elements in full_dataset
train_size = int(0.7 * DATASET_SIZE)

train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
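To see what the two calls do to the data, here is a minimal pure-Python sketch. The list-based take() and skip() helpers are hypothetical stand-ins for Dataset.take() and Dataset.skip(), not TensorFlow code, and the 20-row list is made-up data. A plain list keeps the row order fixed between passes, which is what makes the two halves disjoint; if the underlying dataset reshuffles on every pass (as point 2) below says make_csv_dataset does by default), separate take and skip passes may not see complementary rows.

```python
from itertools import islice


def take(rows, count):
    """Hypothetical list-based analogue of Dataset.take(count)."""
    return list(islice(rows, count))


def skip(rows, count):
    """Hypothetical list-based analogue of Dataset.skip(count)."""
    return list(islice(rows, count, None))


full_dataset = list(range(20))    # hypothetical: 20 rows in a FIXED order
DATASET_SIZE = len(full_dataset)  # take/skip need the size known up front
train_size = int(0.7 * DATASET_SIZE)

train_dataset = take(full_dataset, train_size)  # first 70% of the rows
test_dataset = skip(full_dataset, train_size)   # everything after that

print(len(train_dataset), len(test_dataset))
```

Because skip() starts exactly where take() stops, every row lands in exactly one half.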

How do you fix your code?

Instead of creating a new iterator on every pass through the loop, iterate over the dataset once with a single iterator:

# inspect elements
for feature, label in train_dataset:
    print(feature)

What in your code causes this behavior?

1) Python's built-in iter() function asks an object for an iterator (by calling its __iter__ method). For a Dataset in eager mode, every call to iter(train_dataset) returns a brand-new iterator that starts from the beginning, which is essentially the same as calling Dataset.make_one_shot_iterator() again.

2) By default, tf.contrib.data.make_csv_dataset() shuffles the data (shuffle=True). As a result, each call to iter(train_dataset) creates a new iterator that sees the rows in a different order, so take(1) yields a different row each time.

3) Finally, your for step in range(10) loop therefore creates 10 different iterators, each of size 1, and each one yields a different element because the data is reshuffled every time. That is why take(1) appears to return all of the elements.
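The three points can be reproduced without TensorFlow. The sketch below is a hypothetical toy class, not the tf.data implementation: it only assumes, as described above, that each iter() call starts a fresh pass, that each pass reshuffles the rows, and that take(1) stops after one element. The row values are made up.

```python
import random
from itertools import islice


class ShuffledTakeDataset:
    """Hypothetical toy model of make_csv_dataset(..., shuffle=True).take(n).

    Every call to iter() starts a fresh pass: the rows are reshuffled and
    at most n of them are yielded.
    """

    def __init__(self, rows, n, seed=None):
        self._rows = list(rows)
        self._n = n
        self._rng = random.Random(seed)

    def __iter__(self):
        order = self._rows[:]
        self._rng.shuffle(order)       # new order on every pass (shuffle=True)
        return islice(order, self._n)  # take(n): stop after n elements


rows = list(range(10))  # hypothetical stand-in for the CSV rows
train_dataset = ShuffledTakeDataset(rows, n=1, seed=0)

# The loop from the question: a brand-new iterator on every step.
seen = [next(iter(train_dataset)) for step in range(10)]
print(seen)  # ten separate draws, generally not all the same row

# A single pass over the same object still yields only one element.
print(len(list(train_dataset)))  # 1
```

Each individual pass really does contain one element; the variety in `seen` comes entirely from creating ten passes.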

Suggestion: to avoid this, create the iterator once, outside the loop:

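train_dataset = train_dataset.take(1)
# Create the iterator once, outside the loop (in eager mode, iter()
# returns an iterator that works with next())
iterator = iter(train_dataset)
# inspect elements
for step in range(10):
    features, labels = next(iterator)
    print(list(features))
    # raises StopIteration on the second step, because the dataset has only one element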
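The same loop behaves predictably once there is only one iterator. This pure-Python sketch uses a hypothetical helper, shuffled_take(), as a stand-in for iter(dataset.take(n)) on a shuffled dataset; it is not a TensorFlow API. It also catches StopIteration so the loop ends cleanly instead of raising.

```python
import random
from itertools import islice


def shuffled_take(rows, n, seed=None):
    """Hypothetical model of iter(dataset.take(n)) on a shuffled dataset:
    returns ONE iterator over at most n randomly chosen rows."""
    order = list(rows)
    random.Random(seed).shuffle(order)
    return islice(order, n)


rows = list(range(10))  # hypothetical stand-in for the CSV rows
iterator = shuffled_take(rows, 1, seed=42)  # created once, outside the loop

collected = []
for step in range(10):
    try:
        collected.append(next(iterator))
    except StopIteration:
        # The one-element iterator is exhausted on the second step.
        print(f"StopIteration at step {step}")  # prints "StopIteration at step 1"
        break

print(collected)  # exactly one row
```

In practice, the plain `for feature, label in train_dataset:` form handles StopIteration automatically, which is why it is the simpler fix.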

Upvotes: 1
