I currently have a TensorFlow Dataset with a number of batches (the number of batches varies, but is always divisible by 4). I want to take out every 4th batch for testing and use the rest for training, but I have yet to find an elegant solution. A simplified visual example of the desired result:
Dataset = [b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12]
train = [b1,b2,b3,b5,b6,b7,b9,b10,b11]
test = [b4,b8,b12]
Most solutions for train-validation-test splits of Datasets use a combination of Dataset.take() and Dataset.skip(), since they don't mind splitting the data somewhere down the middle. If I were to use this approach, however, I would have to compute the size of the dataset, run an ugly loop over it with multiple take()s and skip()s, then collect the results and concatenate them into a new Dataset. Is there no better way to select intervals of batches in a TensorFlow dataset?
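The solution can be achieved through a combination of enumerate(), filter(), and map(), similar to the answer provided here. enumerate() pairs each batch with its index, filter() drops the batches whose position is a multiple of 4, and map() strips the index again.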
Toy example:
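import numpy as np
import tensorflow as tf

Dataset = tf.data.Dataset  # alias so the snippets can refer to Dataset directly

list(
    Dataset.from_tensor_slices(np.arange(12))
    .batch(2)  # 6 batches of 2 elements each
    .as_numpy_iterator()
)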
Output:
[array([0, 1]),
array([2, 3]),
array([4, 5]),
array([6, 7]),
array([8, 9]),
array([10, 11])]
Solution on toy example:
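list(
    Dataset.from_tensor_slices(np.arange(12))
    .batch(2)
    # solution starts here
    .enumerate()  # yields (index, batch) pairs, index starting at 0
    .filter(lambda i, data: (i + 1) % 4 != 0)  # drop every 4th batch
    .map(lambda i, data: data)  # discard the index again
    # solution ends here
    .as_numpy_iterator()
)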
Output:
[array([0, 1]),
array([2, 3]),
array([4, 5]),
array([8, 9]),
array([10, 11])]
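The snippet above only produces the training part. As a rough sketch, the same enumerate()/filter()/map() pattern can give both splits by inverting the filter condition. The helper name split_every_nth and its parameter n are made up for illustration and are not part of TensorFlow. The toy data uses 24 elements in batches of 2, so there are 12 batches as in the question.

```python
import numpy as np
import tensorflow as tf

# Toy data: 24 elements in batches of 2 gives 12 batches, as in the question
ds = tf.data.Dataset.from_tensor_slices(np.arange(24)).batch(2)


def split_every_nth(dataset, n=4):
    """Hypothetical helper: every n-th batch goes to test, the rest to train."""
    indexed = dataset.enumerate()  # (index, batch) pairs, index starts at 0
    train = indexed.filter(
        lambda i, data: tf.not_equal((i + 1) % n, 0)
    ).map(lambda i, data: data)
    test = indexed.filter(
        lambda i, data: tf.equal((i + 1) % n, 0)
    ).map(lambda i, data: data)
    return train, test


train_ds, test_ds = split_every_nth(ds, n=4)

train_batches = [b.tolist() for b in train_ds.as_numpy_iterator()]
test_batches = [b.tolist() for b in test_ds.as_numpy_iterator()]

# For the test split alone, Dataset.shard(num_shards, index) keeps the
# elements whose position modulo num_shards equals index
shard_batches = [b.tolist() for b in ds.shard(4, 3).as_numpy_iterator()]

print(test_batches)   # every 4th batch
print(len(train_batches), len(test_batches))
```

Note that both splits iterate over the source dataset independently, so if the source is shuffled with reshuffling on each iteration, the two filters may not see the same order. Shuffling after the split, or using a fixed order, avoids that.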