Derk

Split a TensorFlow dataset into one dataset per class

I have a dataset created from one TFRecord file. The dataset contains samples from 5 different classes.

Now I want to create batches that contain a fixed number of elements (8, for example) from each class. So each batch should contain 40 elements: 8 elements from each of the 5 classes.

Is this possible with tf.data?

Answers (1)

Ujjwal

The easiest approach (though perhaps not the most convenient) is:

a) Prepare 5 different TFRecord files, each containing elements of only one specific class.

b) Create 5 different tf.data.TFRecordDataset instances and hence 5 different iterators.

c) Then, in the main code:

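import tensorflow as tf
# class_datasets is a hypothetical list holding the 5 per-class datasets from step b)
iterators = [iter(ds.batch(8)) for ds in class_datasets]  # one iterator per class (TF2 eager)
data = [it.get_next() for it in iterators]  # 8 elements from each class
data_to_use = tf.concat(data, axis=0)  # one batch of 40 elements; assumes each element is a single tensor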
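A minimal, self-contained sketch of the first approach, assuming TensorFlow 2.x with eager execution. The record layout (a float `feature` and an int64 `label`), the helper names `make_example` and `parse`, the file names, and the synthetic data are hypothetical stand-ins for your own records. `.repeat()` is added so that no per-class iterator runs out before the others.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

NUM_CLASSES = 5
PER_CLASS = 8  # elements taken from each class per batch


def make_example(feature, label):
    # Hypothetical record layout: one float feature and one int64 label.
    return tf.train.Example(features=tf.train.Features(feature={
        "feature": tf.train.Feature(float_list=tf.train.FloatList(value=[feature])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    })).SerializeToString()


def parse(serialized):
    # Decode one serialized record back into a (feature, label) pair.
    spec = {
        "feature": tf.io.FixedLenFeature([], tf.float32),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, spec)
    return parsed["feature"], parsed["label"]


# a) Write one TFRecord file per class (synthetic data stands in for real records).
tmp_dir = tempfile.mkdtemp()
paths = []
for c in range(NUM_CLASSES):
    path = os.path.join(tmp_dir, f"class_{c}.tfrecord")
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(20):
            writer.write(make_example(float(i), c))
    paths.append(path)

# b) One dataset per file, each batched with PER_CLASS elements, and one iterator each.
iterators = [
    iter(tf.data.TFRecordDataset(p).map(parse).repeat().batch(PER_CLASS))
    for p in paths
]

# c) Take one batch from every iterator and concatenate them into a single batch.
parts = [next(it) for it in iterators]
features = tf.concat([f for f, _ in parts], axis=0)
labels = tf.concat([l for _, l in parts], axis=0)
print(features.shape, np.bincount(labels.numpy()))  # → (40,) [8 8 8 8 8]
```

Because each class lives in its own file, every batch is balanced by construction; the cost is maintaining five files instead of one.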

Another approach (without creating separate TFRecord files):

a) Use only one TFRecord file, but create 5 different dataset instances from it.

b) In each instance, use the tf.data.Dataset.filter(predicate) method of the tf.data API to keep only the records that belong to one specific class. For that, you will have to write a predicate function that checks the class of each record.

c) Then follow step c) as in the previous solution.
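The second approach can be sketched the same way, again assuming TensorFlow 2.x and the same hypothetical record layout and helpers (`make_example`, `parse`, and the helper `has_class` are illustrative names, not part of any API). Instead of concatenating outside the pipeline, this version uses `tf.data.Dataset.zip` to pair up one batch from each filtered dataset and concatenates them inside `map`, so the result is itself a dataset of 40-element batches. That is a variation on step c), not something the answer prescribes.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

NUM_CLASSES = 5
PER_CLASS = 8  # elements taken from each class per batch


def make_example(feature, label):
    # Hypothetical record layout: one float feature and one int64 label.
    return tf.train.Example(features=tf.train.Features(feature={
        "feature": tf.train.Feature(float_list=tf.train.FloatList(value=[feature])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    })).SerializeToString()


def parse(serialized):
    # Decode one serialized record back into a (feature, label) pair.
    spec = {
        "feature": tf.io.FixedLenFeature([], tf.float32),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, spec)
    return parsed["feature"], parsed["label"]


# Synthetic stand-in for the single TFRecord file that mixes all classes.
path = os.path.join(tempfile.mkdtemp(), "all_classes.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for i in range(100):
        writer.write(make_example(float(i), i % NUM_CLASSES))

base = tf.data.TFRecordDataset(path).map(parse)


def has_class(c):
    # Build a Dataset.filter predicate that keeps only records whose label equals c.
    return lambda feature, label: tf.equal(label, c)


# a) + b) Five filtered views of the same file, one per class, each batched.
per_class = [
    base.filter(has_class(c)).repeat().batch(PER_CLASS) for c in range(NUM_CLASSES)
]

# c) Zip one batch from each class and concatenate them into a 40-element batch.
balanced = tf.data.Dataset.zip(tuple(per_class)).map(
    lambda *parts: (
        tf.concat([f for f, _ in parts], axis=0),
        tf.concat([l for _, l in parts], axis=0),
    )
)

features, labels = next(iter(balanced))
print(features.shape, np.bincount(labels.numpy()))  # → (40,) [8 8 8 8 8]
```

Each filtered dataset scans the entire file and discards the records of the other four classes, so this version reads the data five times; that is the price of not splitting the file.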
