Reputation: 897
Does anyone know how to split a dataset created by the Dataset API (tf.data.Dataset) in TensorFlow into test and train sets?
Upvotes: 73
Views: 81899
Reputation: 4165
Since TensorFlow 2.10.0 there is a tf.keras.utils.split_dataset function; see the release notes:
Added tf.keras.utils.split_dataset utility to split a Dataset object or a list/tuple of arrays into two Dataset objects (e.g. train/test).
Upvotes: 8
Reputation: 3384
Beware of lazy evaluation: it produces two separate pipelines, shuffle+take and shuffle+skip, that do overlap. Because of this, some of the high-scored answers leak information between the splits. Here is the correct way: repeat the shuffle, with the same seed, in both the train and test datasets.
import tensorflow as tf

def gen_data():
    return iter(range(10))

ds = tf.data.Dataset.from_generator(
    gen_data,
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32, name="element"))

SEED = 42  # NOTE: with no seed, you overlap train and test!
# reshuffle_each_iteration=False keeps the seeded order fixed across epochs,
# so iterating ds_train again cannot pull in the test elements.
ds_train = ds.shuffle(100, seed=SEED, reshuffle_each_iteration=False).take(8).shuffle(100)
ds_test = ds.shuffle(100, seed=SEED, reshuffle_each_iteration=False).skip(8)

A = set(ds_train.as_numpy_iterator())
B = set(ds_test.as_numpy_iterator())
assert A.intersection(B) == set()
print(list(A))
print(list(B))
NOTE: This works for any deterministically ordered iterator.
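A simpler variant of the same idea, sketched here as an alternative rather than as the answer's own method: shuffle once with a fixed seed and reshuffle_each_iteration=False, then take() and skip() from that single shuffled dataset. Because the order is not redrawn on each iteration, the training split keeps the same elements across epochs and never overlaps the test split. The range dataset is a stand-in for real data.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)  # stand-in for real data

# One shuffle with a fixed seed; the same order is reused on every iteration.
shuffled = ds.shuffle(10, seed=42, reshuffle_each_iteration=False)
ds_train = shuffled.take(8)
ds_test = shuffled.skip(8)

# Iterate the training split twice, as two training epochs would.
epoch_1 = set(ds_train.as_numpy_iterator())
epoch_2 = set(ds_train.as_numpy_iterator())
test = set(ds_test.as_numpy_iterator())

print(sorted(int(x) for x in epoch_1), sorted(int(x) for x in test))
```

If the training order should still vary per epoch, a second .shuffle() after take() reorders only the 8 training elements, as the trailing shuffle in the answer does.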
Upvotes: 0
Reputation: 191
A robust way to split a dataset into two parts is to first deterministically map every item in the dataset into a bucket with, for example, tf.strings.to_hash_bucket_fast. Then you can split the dataset in two by filtering on the bucket. If you split your data into five buckets, you get an 80-20 split, assuming the items are spread evenly across the buckets.
As an example, assume that your dataset contains dictionaries with the key filename. We split the data into five buckets based on this key. This add_fold function adds the key "fold" to the dictionaries:
def add_fold(buckets: int):
    def add_(sample, label):
        fold = tf.strings.to_hash_bucket(sample["filename"], num_buckets=buckets)
        return {**sample, "fold": fold}, label
    return add_

dataset = dataset.map(add_fold(buckets=5))
Now we can split the dataset into two disjoint datasets with Dataset.filter:
def pick_fold(fold: int):
    def filter_fn(sample, _):
        return tf.math.equal(sample["fold"], fold)
    return filter_fn

def skip_fold(fold: int):
    def filter_fn(sample, _):
        return tf.math.not_equal(sample["fold"], fold)
    return filter_fn

train_dataset = dataset.filter(skip_fold(0))
val_dataset = dataset.filter(pick_fold(0))
The key that you use for hashing should be one that captures the correlations in the dataset. For example, if samples collected by the same person are correlated and you want all samples from the same collector to end up in the same bucket (and therefore the same split), use the collector name or ID as the hashing key.
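A small sketch of hashing a grouping key instead of a per-sample key. The collector_id field and the data are hypothetical, invented for illustration; the point is that every sample from one collector lands in the same bucket, so no collector appears in both splits.

```python
import tensorflow as tf

# Hypothetical data: 100 samples, each tagged with the (made-up) collector who recorded it.
collectors = ["alice", "bob", "carol", "dave", "erin"] * 20
dataset = tf.data.Dataset.from_tensor_slices(
    {"collector_id": collectors, "value": list(range(len(collectors)))})

NUM_BUCKETS = 5

def in_val_bucket(sample):
    # Hash the grouping key rather than a per-sample key, so every sample
    # from the same collector is assigned to the same bucket.
    bucket = tf.strings.to_hash_bucket_fast(sample["collector_id"], NUM_BUCKETS)
    return tf.math.equal(bucket, 0)

val_ds = dataset.filter(in_val_bucket)
train_ds = dataset.filter(lambda s: tf.math.logical_not(in_val_bucket(s)))

train_collectors = {s["collector_id"] for s in train_ds.as_numpy_iterator()}
val_collectors = {s["collector_id"] for s in val_ds.as_numpy_iterator()}
n_train = sum(1 for _ in train_ds)
n_val = sum(1 for _ in val_ds)
print(sorted(train_collectors), sorted(val_collectors), n_train, n_val)
```

With only a handful of groups, the realized split can be far from 80/20 (bucket 0 may even be empty), because whole groups move together; the ratio gets closer to the target as the number of groups grows.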
Of course, you can skip the dataset.map step and do the hashing and filtering in one filter function. Here's a full example:
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([f"value-{i}" for i in range(10000)])

def to_bucket(sample):
    return tf.strings.to_hash_bucket_fast(sample, 5)

def filter_train_fn(sample):
    return tf.math.not_equal(to_bucket(sample), 0)

def filter_val_fn(sample):
    return tf.math.logical_not(filter_train_fn(sample))

train_ds = dataset.filter(filter_train_fn)
val_ds = dataset.filter(filter_val_fn)

print(f"Length of training set: {len(list(train_ds.as_numpy_iterator()))}")
print(f"Length of validation set: {len(list(val_ds.as_numpy_iterator()))}")
This prints:
Length of training set: 7995
Length of validation set: 2005
Upvotes: 0
Reputation: 8846
Most of the answers here use take() and skip(), which require knowing the size of your dataset beforehand. This isn't always possible, or the size is difficult or expensive to determine.
Instead, you can slice the dataset so that 1 out of every N records becomes a validation record.
To accomplish this, let's start with a simple dataset of 0-9:
dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
For our example, we're going to slice it so that we have a 3/1 train/validation split, meaning 3 records go to training, then 1 record to validation, then repeat.
split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]
So the first dataset.window(split, split + 1) says to grab split (3) elements, then advance split + 1 elements, and repeat. That + 1 effectively skips the 1 element we're going to use in our validation dataset.
The flat_map(lambda ds: ds) is needed because window() returns each window as its own small Dataset rather than as individual elements, which we don't want. So we flatten it back out.
Then, for the validation data, we first skip(split), which skips over the first split (3) elements that were grabbed in the first training window, so iteration starts on the 4th element. The window(1, split + 1) then grabs 1 element, advances split + 1 (4), and repeats.
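The same 1-in-N pattern can also be expressed without window(), using Dataset.enumerate() and a modulo filter on the index. This is a swapped-in technique, not the one described above, and the helper name every_nth_split is hypothetical. Because each element is passed through untouched, it also works for tuple or dict elements, at the cost of reading the input once per split.

```python
import tensorflow as tf

def every_nth_split(dataset, n):
    """Hypothetical helper: element i goes to validation when i % n == n - 1."""
    indexed = dataset.enumerate()  # yields (index, element) pairs
    train = (indexed
             .filter(lambda i, x: tf.math.not_equal(i % n, n - 1))
             .map(lambda i, x: x))
    val = (indexed
           .filter(lambda i, x: tf.math.equal(i % n, n - 1))
           .map(lambda i, x: x))
    return train, val

split = 3
train_ds, val_ds = every_nth_split(tf.data.Dataset.range(10), split + 1)
train_values = [int(v) for v in train_ds.as_numpy_iterator()]
val_values = [int(v) for v in val_ds.as_numpy_iterator()]
print(train_values, val_values)
```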
Note on nested datasets:
The above example works well for simple datasets, but flat_map() will raise an error if the dataset is nested (for example, if each element is a (features, label) tuple). To address this, you can swap out the flat_map() for a more complicated version that handles both simple and nested datasets:
.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
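A quick check of that replacement on a dataset of (feature, label) tuples, where the plain flat_map(lambda ds: ds) fails because each window arrives as a tuple of datasets. The feature and label values are made up for illustration, and the lambda is written as a named function for readability.

```python
import tensorflow as tf

def flatten_window(*ds):
    # A plain dataset's window is one Dataset; a tuple dataset's window is a
    # tuple of Datasets, which must be zipped back into (feature, label) pairs.
    return ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds)

split = 3
features = tf.data.Dataset.range(10)
labels = features.map(lambda x: x * 10)
pairs = tf.data.Dataset.zip((features, labels))

dataset_train = pairs.window(split, split + 1).flat_map(flatten_window)
dataset_validation = pairs.skip(split).window(1, split + 1).flat_map(flatten_window)

train_pairs = [(int(f), int(l)) for f, l in dataset_train.as_numpy_iterator()]
val_pairs = [(int(f), int(l)) for f, l in dataset_validation.as_numpy_iterator()]
print(train_pairs, val_pairs)
```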
Upvotes: 33
Reputation: 183
@ted's answer will cause some overlap. Try this.
train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)
train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)
Use the code below to test it:
import tensorflow as tf

# tf.enable_eager_execution()  # needed on TF 1.x only; TF 2.x runs eagerly by default
dataset = tf.data.Dataset.range(100)
train_size = 20
valid_size = 30
test_size = 50

train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)

for i in train:
    print(i)
for i in valid:
    print(i)
for i in test:
    print(i)
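The take/skip split above assumes full_ds_size is already known. One way to obtain it, sketched under the assumption that TensorFlow can infer the size statically, is Dataset.cardinality(), which returns -1 for infinite and -2 for unknown sizes (for example after filter()). The helper name three_way_split is hypothetical.

```python
import tensorflow as tf

def three_way_split(full_ds, train_frac=0.64, valid_frac=0.16):
    """Hypothetical helper: take/skip split using a statically known size."""
    full_ds_size = int(full_ds.cardinality().numpy())
    if full_ds_size < 0:  # -1: infinite, -2: unknown
        raise ValueError("dataset size is not known; count it or pass it in")
    train_ds_size = int(train_frac * full_ds_size)
    valid_ds_size = int(valid_frac * full_ds_size)
    train_ds = full_ds.take(train_ds_size)
    remaining = full_ds.skip(train_ds_size)
    valid_ds = remaining.take(valid_ds_size)
    test_ds = remaining.skip(valid_ds_size)
    return train_ds, valid_ds, test_ds

train_ds, valid_ds, test_ds = three_way_split(tf.data.Dataset.range(100))
sizes = [sum(1 for _ in d) for d in (train_ds, valid_ds, test_ds)]
print(sizes)
```

When cardinality() cannot tell the size, the choices are to count the elements once in a separate pass or to use a size-free approach such as hashing or window-based splitting.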
Upvotes: 8
Reputation: 15
I can't comment, but the answer above has overlap and is incorrect. Set BUFFER_SIZE to DATASET_SIZE for a perfect shuffle. Try different val/test sizes to verify. The answer should be:
DATASET_SIZE = tf.data.experimental.cardinality(full_dataset).numpy()
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

BUFFER_SIZE = DATASET_SIZE  # a buffer as large as the dataset gives a perfect shuffle
# reshuffle_each_iteration=False fixes the order, so take/skip select the
# same elements on every epoch and the splits cannot drift into each other.
full_dataset = full_dataset.shuffle(BUFFER_SIZE, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.take(val_size)
test_dataset = test_dataset.skip(val_size)
Upvotes: -2
Reputation: 2722
In case the size of the dataset is known:
from typing import Tuple

import tensorflow as tf

def split_dataset(dataset: tf.data.Dataset,
                  dataset_size: int,
                  train_ratio: float,
                  validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
    assert (train_ratio + validation_ratio) < 1

    train_count = int(dataset_size * train_ratio)
    validation_count = int(dataset_size * validation_ratio)
    test_count = dataset_size - (train_count + validation_count)

    # Fix the shuffle order so take/skip select the same elements on every epoch;
    # with the default reshuffle_each_iteration=True the splits can overlap.
    dataset = dataset.shuffle(dataset_size, reshuffle_each_iteration=False)

    train_dataset = dataset.take(train_count)
    validation_dataset = dataset.skip(train_count).take(validation_count)
    test_dataset = dataset.skip(validation_count + train_count).take(test_count)

    return train_dataset, validation_dataset, test_dataset
Example:
size_of_ds = 1001
train_ratio = 0.6
val_ratio = 0.2
ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds)))
train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)
Upvotes: 0
Reputation: 6108
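You can use shard:
# Optional shuffle; BUFFER_SIZE is a placeholder you must define. A fixed order
# keeps the two shards disjoint across epochs.
dataset = dataset.shuffle(BUFFER_SIZE, reshuffle_each_iteration=False)
trainset = dataset.shard(2, 0)
testset = dataset.shard(2, 1)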
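shard(num_shards, index) keeps the elements whose position modulo num_shards equals index, so shard(2, 0) and shard(2, 1) give a 50/50 split of even and odd positions. For an 80/20 split, one option (a sketch, not taken from the answer) is to hold out one of five shards and concatenate the other four. The concatenated training set reads the input once per shard and is grouped by shard rather than kept in the original order.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)  # stand-in for real data
NUM_SHARDS = 5

# Shard 0 (positions 0, 5, ...) is held out: 1/5 = 20% of the data.
valset = dataset.shard(NUM_SHARDS, 0)

# The remaining shards are concatenated into the training set.
trainset = dataset.shard(NUM_SHARDS, 1)
for k in range(2, NUM_SHARDS):
    trainset = trainset.concatenate(dataset.shard(NUM_SHARDS, k))

val_values = [int(v) for v in valset.as_numpy_iterator()]
train_values = [int(v) for v in trainset.as_numpy_iterator()]
print(val_values, train_values)
```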
See: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard
Upvotes: 5
Reputation: 14764
You may use Dataset.take() and Dataset.skip():
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
# shuffle() requires a buffer size (BUFFER_SIZE, like DATASET_SIZE, is yours to define);
# a fixed order keeps the splits disjoint across epochs.
full_dataset = full_dataset.shuffle(BUFFER_SIZE, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)
For more generality, I gave an example using a 70/15/15 train/val/test split, but if you don't need a test or a val set, just ignore the last 2 lines.
take: "Creates a Dataset with at most count elements from this dataset."
skip: "Creates a Dataset that skips count elements from this dataset."
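The "at most" in take()'s description matters when the requested count exceeds the dataset size: take() then returns everything, and skip() past the end returns an empty dataset, with no error in either case. A quick check on a 10-element stand-in dataset:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)

taken = [int(x) for x in ds.take(20).as_numpy_iterator()]    # asks for more than exists
skipped = [int(x) for x in ds.skip(20).as_numpy_iterator()]  # skips past the end
print(len(taken), len(skipped))
```

So a split computed from a wrong DATASET_SIZE fails silently, with short or empty splits, rather than raising.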
You may also want to look into Dataset.shard(): "Creates a Dataset that includes only 1/num_shards of this dataset."
Disclaimer: I stumbled upon this question after answering this one, so I thought I'd spread the love.
Upvotes: 59
Reputation: 1624
Assuming you have an all_dataset variable of type tf.data.Dataset:
test_dataset = all_dataset.take(1000)
train_dataset = all_dataset.skip(1000)
The test dataset now has the first 1000 elements, and the rest go to training.
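If the training set should be reshuffled every epoch, it is safe to do that after the split, since shuffling the skipped part only reorders training elements. A minimal sketch; the range dataset and buffer size are stand-ins. This does not randomize which elements end up in the test set: if all_dataset is ordered (say, sorted by class), the first 1000 elements may not be representative.

```python
import tensorflow as tf

all_dataset = tf.data.Dataset.range(5000)  # stand-in for real data

test_dataset = all_dataset.take(1000)
# Shuffle after splitting: each epoch gets a new order, but only over
# the elements that skip(1000) already assigned to training.
train_dataset = all_dataset.skip(1000).shuffle(buffer_size=4000)

test_values = set(int(x) for x in test_dataset.as_numpy_iterator())
epoch_1 = [int(x) for x in train_dataset.as_numpy_iterator()]
epoch_2 = [int(x) for x in train_dataset.as_numpy_iterator()]
print(len(test_values), len(epoch_1))
```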
Upvotes: 96
Reputation: 347
At the time of this answer, TensorFlow doesn't contain any tools for that.
You could use sklearn.model_selection.train_test_split to generate the train/eval/test splits, then create a tf.data.Dataset for each of them.
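A minimal sketch of that approach for in-memory NumPy arrays, calling train_test_split twice to get train/eval/test. The arrays, split fractions and random_state are made up for illustration. This only works when the data fits in memory, which is the main limitation compared with splitting a tf.data.Dataset directly.

```python
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

# Made-up data: 50 samples with 2 features each, and binary labels.
features = np.arange(100, dtype=np.float32).reshape(50, 2)
labels = np.arange(50) % 2

# First hold out 20% as the test set, then 25% of the rest as the eval set
# (0.8 * 0.25 = 0.2), giving a 60/20/20 split overall.
x_rest, x_test, y_rest, y_test = train_test_split(
    features, labels, test_size=0.2, random_state=42)
x_train, x_eval, y_train, y_eval = train_test_split(
    x_rest, y_rest, test_size=0.25, random_state=42)

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
eval_ds = tf.data.Dataset.from_tensor_slices((x_eval, y_eval))
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
print(len(x_train), len(x_eval), len(x_test))
```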
Upvotes: 5