Reputation: 1395
I have a dataset created from one tfrecord file. This dataset contains 5 different classes.
Now I want to create batches with a fixed number of elements (8, for example) from each class. So it should create batches of 40 elements containing 8 elements of each class.
Is this possible with tf.data?
Upvotes: 2
Views: 1336
Reputation: 208
The easiest approach (though perhaps not very convenient) is:
a) Prepare 5 different TFRecord files, each containing elements of only one specific class.
b) Create 5 different tf.data.TFRecordDataset instances and hence 5 different iterators.
c) Then, in the main code:
iterators = [....] # Store your iterators in a list
data = list(map(lambda x : x.get_next(), iterators))
data_to_use = tf.concat(....) # Concat your data in one single batch of `40` elements.
Alternatively:
a) Use only one TFRecord file, but create 5 different dataset instances from it.
b) On each instance, use the filter(predicate) method of tf.data.Dataset to keep only the records that belong to one specific class. For that you will have to write a function that checks the class of each record.
c) Then follow step c) as in the previous solution.
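A rough sketch of the filter-based variant, in case it helps. It assumes TensorFlow 2.x with eager execution. Instead of a real TFRecord file it uses a small synthetic dataset of (feature, label) pairs built with from_tensor_slices, so the parsing step is left out. The names NUM_CLASSES, PER_CLASS, only_class and merge are made up for illustration, and Dataset.zip followed by map stands in for the manual iterator list and tf.concat step.

```python
import tensorflow as tf

NUM_CLASSES = 5   # assumption: labels are integers 0..4
PER_CLASS = 8     # elements of each class per batch

# Stand-in for the parsed TFRecord dataset (assumption: each element
# is already a (feature, label) pair).
features = tf.range(200, dtype=tf.int64)
labels = features % NUM_CLASSES
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

def only_class(class_id):
    # Build a predicate that keeps records of exactly one class.
    return lambda x, y: tf.equal(y, class_id)

# One filtered, repeated and batched dataset per class.
per_class = [
    dataset.filter(only_class(c)).repeat().batch(PER_CLASS)
    for c in range(NUM_CLASSES)
]

def merge(*batches):
    # Each argument is an (x, y) pair holding PER_CLASS elements of one class.
    xs = tf.concat([b[0] for b in batches], axis=0)
    ys = tf.concat([b[1] for b in batches], axis=0)
    return xs, ys

# Zip the per-class datasets and concatenate into one batch of 40 elements.
balanced = tf.data.Dataset.zip(tuple(per_class)).map(merge)

x, y = next(iter(balanced))
```

The repeat() call is there so that a small class does not end the zipped dataset early; without it, zip stops as soon as the shortest per-class dataset runs out. The elements inside each batch are grouped by class, so shuffle within the batch if the order matters.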
Upvotes: 2