Reputation: 1033
I have a tensorflow dataset based on one .tfrecord file. How do I split the dataset into test and train datasets? E.g. 70% Train and 30% test?
Edit:
My TensorFlow version: 1.8. I've checked; there is no "split_v" function as mentioned in the possible duplicate. Also, I am working with a tfrecord file.
Upvotes: 42
Views: 54508
Reputation: 3354
I will first explain why the accepted answer is wrong, and then provide a simple working solution using take(), skip() and a shuffle seed.
When working with pipelines, such as TF/Torch Datasets, beware of lazy evaluation. Avoid:
# DON'T: take/skip on a reshuffling dataset leaks data between the splits
full_dataset = full_dataset.shuffle(10)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
because take and skip will not synchronize to a single shuffle, but rather get executed as shuffle+take and shuffle+skip separately (yes!), since shuffle reshuffles on every iteration by default. With an 80/20 split, the two splits typically overlap in 80% * 20% = 16% of the elements. So, information leak.
Play with this code in case of doubt:
import tensorflow as tf

def gen_data():
    return iter(range(10))

full_dataset = tf.data.Dataset.from_generator(
    gen_data,
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32, name="element"))

train_size = 8

# WRONG WAY
full_dataset = full_dataset.shuffle(10)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)

A = set(train_dataset.as_numpy_iterator())
B = set(test_dataset.as_numpy_iterator())

# EXPECT OVERLAP: this assertion will typically fail
assert A.intersection(B) == set()
print(list(A))
print(list(B))
Now, what works is repeating the shuffle with the same seed in both the train and the test dataset, which is also good for reproducibility. This should work with any deterministically ordered iterator:
import tensorflow as tf

def gen_data():
    return iter(range(10))

ds = tf.data.Dataset.from_generator(
    gen_data,
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32, name="element"))

SEED = 42  # NOTE: change this

# Same seed on both shuffles => same first-pass permutation => disjoint splits.
# The extra shuffle on ds_train only reshuffles within the training examples.
ds_train = ds.shuffle(100, seed=SEED).take(8).shuffle(100)
ds_test = ds.shuffle(100, seed=SEED).skip(8)

A = set(ds_train.as_numpy_iterator())
B = set(ds_test.as_numpy_iterator())
assert A.intersection(B) == set()
print(list(A))
print(list(B))
By playing with SEED you can, for instance, inspect/estimate generalization (bootstrapping in place of cross-validation).
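A minimal sketch of that idea (single pass per seed; the train-and-score step is left as a placeholder for your own model):

import tensorflow as tf

def make_split(ds, seed, train_size=8, buffer=100):
    # The same seeded shuffle in both pipelines keeps the splits disjoint.
    train = ds.shuffle(buffer, seed=seed).take(train_size).shuffle(buffer)
    test = ds.shuffle(buffer, seed=seed).skip(train_size)
    return train, test

ds = tf.data.Dataset.range(10)

for seed in range(3):  # each seed resamples the train/test split
    train, test = make_split(ds, seed)
    train_set = set(train.as_numpy_iterator())
    test_set = set(test.as_numpy_iterator())
    assert not train_set & test_set  # still disjoint for every seed
    print("seed", seed, "test:", sorted(test_set))
    # train and evaluate your model here, then aggregate scores across seeds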
Upvotes: 2
Reputation: 5949
This question is similar to this one and this one, and I am afraid we have not had a satisfactory answer yet.
Using take() and skip() requires knowing the dataset size. What if I don't know that, or don't want to find out?
Using shard() only gives 1 / num_shards of the dataset. What if I want the rest?
I try to present a better solution below, tested on TensorFlow 2 only. Assuming you already have a shuffled dataset, you can then use filter() to split it into two:
import tensorflow as tf

all = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
    .shuffle(10, reshuffle_each_iteration=False)

test_dataset = all.enumerate() \
    .filter(lambda x, y: x % 4 == 0) \
    .map(lambda x, y: y)

train_dataset = all.enumerate() \
    .filter(lambda x, y: x % 4 != 0) \
    .map(lambda x, y: y)

for i in test_dataset:
    print(i)
print()
for i in train_dataset:
    print(i)
The parameter reshuffle_each_iteration=False is important. It makes sure the original dataset is shuffled once and no more. Otherwise, the two resulting sets may have some overlap.
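A quick runnable sketch to see the difference:

import tensorflow as tf

# Default reshuffle_each_iteration=True: a new order on every pass.
ds = tf.data.Dataset.range(5).shuffle(5)
print(list(ds.as_numpy_iterator()))  # e.g. [3, 0, 4, 1, 2]
print(list(ds.as_numpy_iterator()))  # most likely a different order

# With reshuffle_each_iteration=False: shuffled once, stable across passes.
ds_fixed = tf.data.Dataset.range(5).shuffle(5, reshuffle_each_iteration=False)
print(list(ds_fixed.as_numpy_iterator()))
print(list(ds_fixed.as_numpy_iterator()))  # same order as the line above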
Use enumerate() to add an index.
Use filter(lambda x,y: x % 4 == 0) to take 1 sample out of 4. Likewise, x % 4 != 0 takes 3 out of 4.
Use map(lambda x,y: y) to strip the index and recover the original sample.
This example achieves a 75/25 split. x % 5 == 0 and x % 5 != 0 give an 80/20 split. If you really want a 70/30 split, x % 10 < 3 and x % 10 >= 3 should do.
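For completeness, a runnable sketch of that 70/30 variant on the same toy dataset:

import tensorflow as tf

all = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
    .shuffle(10, reshuffle_each_iteration=False)

# Positions 0, 1, 2 of every block of 10 go to test: a 70/30 split.
test_dataset = all.enumerate() \
    .filter(lambda x, y: x % 10 < 3) \
    .map(lambda x, y: y)
train_dataset = all.enumerate() \
    .filter(lambda x, y: x % 10 >= 3) \
    .map(lambda x, y: y)

print(len(list(test_dataset)))   # 6 of 20 elements (30%)
print(len(list(train_dataset)))  # 14 of 20 elements (70%)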
UPDATE:
As of TensorFlow 2.0.0, the above code may result in some warnings due to AutoGraph's limitations. To eliminate those warnings, declare all lambda functions separately:
def is_test(x, y):
    return x % 4 == 0

def is_train(x, y):
    return not is_test(x, y)

recover = lambda x, y: y

test_dataset = all.enumerate() \
    .filter(is_test) \
    .map(recover)

train_dataset = all.enumerate() \
    .filter(is_train) \
    .map(recover)
This gives no warnings on my machine. And defining is_train() as not is_test() is definitely good practice.
Upvotes: 44
Reputation: 14734
You may use Dataset.take() and Dataset.skip():
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle(buffer_size=DATASET_SIZE)  # shuffle() requires a buffer_size
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)
For more generality, I gave an example using a 70/15/15 train/val/test split, but if you don't need a test or a val set, just ignore the last two lines.
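Note that DATASET_SIZE is assumed to be known above. For a TFRecord file, one way to obtain it (a sketch, assuming TensorFlow 2 eager mode and that a full pass over the file is affordable) is:

import tensorflow as tf

# One full pass over the file just to count the records.
DATASET_SIZE = sum(1 for _ in tf.data.TFRecordDataset(FLAGS.input_file))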
Take:
Creates a Dataset with at most count elements from this dataset.
Skip:
Creates a Dataset that skips count elements from this dataset.
You may also want to look into Dataset.shard():
Creates a Dataset that includes only 1/num_shards of this dataset.
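As a sketch of how shard() could be used here (there is no built-in complement; concatenating the remaining shards is one workaround, though it interleaves their order):

import tensorflow as tf

ds = tf.data.Dataset.range(10)

# shard(num_shards=5, index=0) keeps every 5th element: roughly 20%.
val_dataset = ds.shard(num_shards=5, index=0)
print(list(val_dataset.as_numpy_iterator()))  # [0, 5]

# The other ~80%: concatenate the remaining shards (order gets interleaved).
train_dataset = ds.shard(5, 1)
for i in range(2, 5):
    train_dataset = train_dataset.concatenate(ds.shard(5, i))
print(sorted(train_dataset.as_numpy_iterator()))  # [1, 2, 3, 4, 6, 7, 8, 9]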
Upvotes: 56