Jeff

Split and Recombine a TensorFlow Dataset

I currently have a TensorFlow Dataset made up of batches (the number of batches varies but is always divisible by 4). I want to take out every 4th batch for testing and use the rest for training, but I have yet to find an elegant solution. A simplified example of the desired result:

Dataset = [b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12]
train = [b1,b2,b3,b5,b6,b7,b9,b10,b11]
test = [b4,b8,b12]
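Stated in terms of 0-based positions (the form that Dataset.enumerate() produces), the desired split sends batch i to the test set when (i + 1) % 4 == 0. A plain-Python sketch of that rule on the labels above, using ordinary lists rather than a TensorFlow Dataset (the string labels are just stand-ins for real batches):

```python
# Illustration only: the batch labels b1..b12 as plain strings, not real batches.
dataset = [f"b{k}" for k in range(1, 13)]

# With 0-based position i, b(i+1) is "every 4th batch" exactly when (i + 1) % 4 == 0.
train = [b for i, b in enumerate(dataset) if (i + 1) % 4 != 0]
test = [b for i, b in enumerate(dataset) if (i + 1) % 4 == 0]

print(train)  # ['b1', 'b2', 'b3', 'b5', 'b6', 'b7', 'b9', 'b10', 'b11']
print(test)   # ['b4', 'b8', 'b12']
```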

Most solutions for train/validation/test splits on Datasets combine Dataset.take() and Dataset.skip(), because they only need to split the data once, somewhere in the middle. Using that approach here would require computing the size of the dataset, running an ugly loop with multiple take() and skip() calls, and then concatenating the results into a new Dataset. Is there a better way to select batches at regular intervals from a TensorFlow Dataset?

Answers (1)

Jeff

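This can be done by combining enumerate(), filter(), and map(), similar to the answer provided here: enumerate() pairs each batch with its index, filter() drops batches based on that index, and map() strips the index off again.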

Toy example:

import numpy as np
import tensorflow as tf

Dataset = tf.data.Dataset

list(
    Dataset.from_tensor_slices(np.arange(12))
    .batch(2)
    .as_numpy_iterator()
)

Output:

[array([0, 1]),
 array([2, 3]),
 array([4, 5]),
 array([6, 7]),
 array([8, 9]),
 array([10, 11])]

Solution applied to the toy example:

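list(
    Dataset.from_tensor_slices(np.arange(12))
    .batch(2)
    # solution starts here
    .enumerate()  # yields (index, batch) pairs; index is an int64 tensor
    .filter(lambda i, data: (i + 1) % 4 != 0)  # drop every 4th batch (i = 3, 7, 11, ...)
    .map(lambda i, data: data)  # discard the index, keep only the batch
    # solution ends here
    .as_numpy_iterator()
)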

Output:

[array([0, 1]), 
 array([2, 3]), 
 array([4, 5]),
 array([8, 9]),
 array([10, 11])]
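The answer above builds only the training split; the test split is the same pipeline with the condition inverted. A minimal sketch that returns both, assuming TensorFlow 2.x (where Dataset.enumerate() exists). The helper name split_every_nth is hypothetical, and the 24-element input is chosen so that there are 12 batches, matching the question's b1..b12:

```python
import numpy as np
import tensorflow as tf


def split_every_nth(ds: tf.data.Dataset, n: int = 4):
    """Hypothetical helper: send every n-th element of `ds` to a test split.

    Element i (0-based, as counted by enumerate()) goes to test when
    (i + 1) % n == 0; every other element goes to train.
    """
    indexed = ds.enumerate()  # yields (index, element) pairs

    train = indexed.filter(lambda i, data: tf.not_equal((i + 1) % n, 0)).map(
        lambda i, data: data  # drop the index again
    )
    test = indexed.filter(lambda i, data: tf.equal((i + 1) % n, 0)).map(
        lambda i, data: data
    )
    return train, test


# 12 batches of 2 numbers each, standing in for b1..b12 from the question.
batches = tf.data.Dataset.from_tensor_slices(np.arange(24)).batch(2)
train, test = split_every_nth(batches, n=4)

train_batches = [b.tolist() for b in train.as_numpy_iterator()]
test_batches = [b.tolist() for b in test.as_numpy_iterator()]
print(train_batches)  # 9 batches: b1, b2, b3, b5, b6, b7, b9, b10, b11
print(test_batches)   # [[6, 7], [14, 15], [22, 23]], i.e. b4, b8, b12
```

Both splits re-run the upstream pipeline when iterated, so an expensive source may benefit from .cache() before splitting, and an upstream shuffle that reshuffles on every iteration would make the two splits inconsistent with each other. For the test split alone, batches.shard(4, 3) selects the same elements (indices 3, 7, 11).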
