Nishanth Dikkala

Reputation: 13

How to extract samples from a Tensorflow Dataset which have the same label?

I want to generate batched samples from a TF Dataset in the following manner: each batch should consist of two samples that have the same 'label' feature. What is the most efficient way to achieve this in Tensorflow?

Upvotes: 1

Views: 627

Answers (1)

AloneTogether

Reputation: 26708

Let's assume you have some kind of data with multiple labels like this:

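import tensorflow as tf

# 10 samples with 2 integer features each, and integer labels in {0, 1, 2, 3}
x = tf.random.uniform((10, 2), maxval=20, dtype=tf.int32)
y = tf.random.uniform((10, ), maxval=4, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y))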

To achieve your goal ("each batch would consist of two samples which have the same 'label' feature"), you could use the filter function to create a separate dataset for each label, batch each of them, and then use the concatenate function of the tf.data.Dataset API to merge these datasets into one:

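import tensorflow as tf

x = tf.random.uniform((10, 2), maxval=20, dtype=tf.int32)
y = tf.random.uniform((10, ), maxval=4, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y))

batch_size = 2
# One filtered dataset per label value. drop_remainder=True discards a trailing
# batch that would hold a single sample, so every batch has exactly two samples.
dataset0 = dataset.filter(lambda x, y: tf.equal(y, 0)).batch(batch_size, drop_remainder=True)
dataset1 = dataset.filter(lambda x, y: tf.equal(y, 1)).batch(batch_size, drop_remainder=True)
dataset2 = dataset.filter(lambda x, y: tf.equal(y, 2)).batch(batch_size, drop_remainder=True)
dataset3 = dataset.filter(lambda x, y: tf.equal(y, 3)).batch(batch_size, drop_remainder=True)

# Chain the per-label datasets, then shuffle the order of the already-formed batches.
dataset = dataset0.concatenate(dataset1).concatenate(dataset2).concatenate(dataset3).shuffle(buffer_size=20)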
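
If the number of distinct labels is large or not known in advance, writing one filter per label gets unwieldy, and every filter has to read the whole dataset again. A possible alternative is to group by label in a single pass with group_by_window. This is only a sketch: it assumes a recent TensorFlow 2.x release in which group_by_window is available as a Dataset method (older releases expose it as tf.data.experimental.group_by_window, applied via dataset.apply), and the name paired is just illustrative.

```python
import tensorflow as tf

x = tf.random.uniform((10, 2), maxval=20, dtype=tf.int32)
y = tf.random.uniform((10,), maxval=4, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((x, y))

batch_size = 2

paired = dataset.group_by_window(
    # The grouping key must be a scalar tf.int64, so cast the int32 label.
    key_func=lambda features, label: tf.cast(label, tf.int64),
    # Each window holds up to `batch_size` samples that share one label.
    # drop_remainder=True discards windows left incomplete at the end of the input.
    reduce_func=lambda key, window: window.batch(batch_size, drop_remainder=True),
    window_size=batch_size,
)

for features, labels in paired:
    # Both entries of `labels` are the same value within a batch.
    print(labels.numpy(), features.numpy().tolist())
```

Samples whose label never finds a partner are dropped here, just as with drop_remainder=True in the filter approach. You could still append .shuffle(...) to paired if the order of the batches should be randomized.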

Upvotes: 1
