ArthurSeat

TF Data API: how to produce TensorFlow input for object set recognition

Consider this problem: from an image dataset (such as ImageNet), select a random number of samples from randomly chosen classes (subjects) and feed them as one input element to a TensorFlow graph that works as an object set recognizer. Within a batch, every class must have the same number of samples to simplify computation, but that number can differ between batches, e.g. batch_0: num_imgs_per_cls=2; batch_1000: num_imgs_per_cls=3.

If TensorFlow already has functionality for this, an explanation of the whole process from scratch (e.g. starting from directories of images) would be much appreciated.


Olivier Moindrot

There is a very similar answer by @mrry here.

Sampling balanced batches

In face recognition we often use triplet loss (or similar losses) to train the model. The usual way to sample triplets to compute the loss is to create a balanced batch of images where we have for instance 10 different classes (i.e. 10 different people) with 5 images each. This gives a total batch size of 50 in this example.

More generally the problem is to sample num_classes_per_batch (10 in the example) classes, and then sample num_images_per_class (5 in the example) images for each class. The total batch size is:

batch_size = num_classes_per_batch * num_images_per_class

Have one dataset for each class

The easiest way to deal with a lot of different classes (100,000 in MS-Celeb) is to create one dataset for each class.
For instance you can have one tfrecord for each class and create the datasets like this:

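import tensorflow as tf

# Build one dataset per class (one TFRecord file per class).
filenames = ["class_0.tfrecords", "class_1.tfrecords", ...]  # list every class file here
# repeat(None) repeats each per-class dataset indefinitely
per_class_datasets = [tf.data.TFRecordDataset(f).repeat(None) for f in filenames]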
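The question also asks how to start from directories of images. As a minimal plain-Python sketch of the bookkeeping (not a tf.data pipeline), assuming a hypothetical layout root/<class_name>/<image files> with at least one file per class, each class gets its own endless stream of file paths, with itertools.cycle playing the role of .repeat(None). The function name per_class_file_streams is hypothetical.

```python
import itertools
import os
import random


def per_class_file_streams(root, seed=None):
    """Return (class_names, streams) where streams[i] yields file paths of class i forever.

    Assumes a hypothetical layout root/<class_name>/<image files>. Each class's
    file list is shuffled once and then cycled, mirroring .repeat(None).
    """
    rng = random.Random(seed)
    # One subdirectory per class, sorted so class indices are stable.
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    streams = []
    for name in class_names:
        class_dir = os.path.join(root, name)
        files = sorted(
            os.path.join(class_dir, f)
            for f in os.listdir(class_dir)
            if os.path.isfile(os.path.join(class_dir, f))
        )
        rng.shuffle(files)
        # Endless stream of this class's files (assumes the list is non-empty).
        streams.append(itertools.cycle(files))
    return class_names, streams
```

Unlike a pipeline with a per-epoch shuffle stage, this sketch shuffles each class's file list only once, so every pass over a class visits its files in the same order.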

Sample from the datasets

Now we would like to be able to sample from these datasets. For instance we want the following labels in our batch:

1 1 1 3 3 3 9 9 9 4 4 4

This corresponds to num_classes_per_batch=4 and num_images_per_class=3.
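The label layout above can be reproduced in plain Python as a sanity check. This is a sketch, not TensorFlow code; sample_batch_labels is a hypothetical helper that draws num_classes_per_batch distinct classes and repeats each one num_images_per_class times. Passing a different num_images_per_class per call gives the varying per-batch counts from the question.

```python
import random


def sample_batch_labels(num_classes, num_classes_per_batch, num_images_per_class, rng=random):
    """Return the label layout of one balanced batch, e.g. [1, 1, 1, 3, 3, 3, 9, 9, 9, 4, 4, 4].

    Draws num_classes_per_batch distinct classes, then repeats each one
    num_images_per_class times in a contiguous run.
    """
    sampled = rng.sample(range(num_classes), num_classes_per_batch)
    return [c for c in sampled for _ in range(num_images_per_class)]


# As in the question, the number of images per class may change between batches:
rng = random.Random(0)
batch_0 = sample_batch_labels(10, 4, 2, rng)     # 4 classes x 2 images = 8 labels
batch_1000 = sample_batch_labels(10, 4, 3, rng)  # 4 classes x 3 images = 12 labels
```

Note that when num_images_per_class varies, the total batch size varies too, so a fixed dataset.batch(batch_size) would no longer line up with the class groups.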

To do this we will need to use features that will be released in r1.9. The function should be called tf.contrib.data.choose_from_datasets (see here for a discussion on this).
It should look like:

def choose_from_datasets(datasets, selector):
    """Returns a dataset that, for each index i produced by `selector`, yields the next element of `datasets[i]`."""
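To make those semantics concrete, here is a plain-Python stand-in (not the TensorFlow implementation; the name choose_from_datasets is reused only for illustration). It assumes every per-class source is an infinite iterator, so exhaustion never comes up, and toy (class, index) tuples stand in for images.

```python
import itertools


def choose_from_datasets(datasets, selector):
    """Plain-Python sketch: for each index i produced by selector,
    yield the next element of datasets[i]."""
    for i in selector:
        yield next(datasets[i])


# Toy per-class "datasets": class c yields (c, 0), (c, 1), (c, 2), ... forever.
datasets = [zip(itertools.repeat(c), itertools.count()) for c in range(10)]
selector = [1, 1, 1, 3, 3, 3, 9, 9, 9, 4, 4, 4]
elements = list(choose_from_datasets(datasets, selector))
```

Each per-class iterator keeps its own position, so class 1 resumes at index 3 the next time the selector picks it.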

So we create a selector that outputs 1 1 1 3 3 3 9 9 9 4 4 4 (with a fresh random choice of classes for every batch) and combine it with the per-class datasets to obtain a final dataset that outputs balanced batches:

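# Example values from above; there is one per-class dataset per class.
num_classes = len(per_class_datasets)
num_classes_per_batch = 4
num_images_per_class = 3

def generator(_):
    # The Counter value passed in is ignored; each call builds the labels of one batch.
    # Sample `num_classes_per_batch` classes for the batch
    sampled = tf.random_shuffle(tf.range(num_classes))[:num_classes_per_batch]
    # Repeat each element `num_images_per_class` times
    batch_labels = tf.tile(tf.expand_dims(sampled, -1), [1, num_images_per_class])
    return tf.to_int64(tf.reshape(batch_labels, [-1]))

# Counter() yields 0, 1, 2, ... forever; map() turns each value into one batch of
# labels, and unbatch() flattens those into a stream of scalar class indices.
selector = tf.contrib.data.Counter().map(generator)
selector = selector.apply(tf.contrib.data.unbatch())

dataset = tf.contrib.data.choose_from_datasets(per_class_datasets, selector)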

# Batch
batch_size = num_classes_per_batch * num_images_per_class
dataset = dataset.batch(batch_size)

You can test this with the nightly TensorFlow build and by using DirectedInterleaveDataset as a workaround:

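# Workaround until choose_from_datasets is available in a release:
from tensorflow.contrib.data.python.ops.interleave_ops import DirectedInterleaveDataset
dataset = DirectedInterleaveDataset(selector, per_class_datasets)
# Then batch as before
dataset = dataset.batch(batch_size)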

I also wrote about this workaround here.
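To check the batching logic without TensorFlow, the pipeline above can be simulated in plain Python. This is a sketch under the assumption that each class is an infinite iterator; balanced_batches is a hypothetical helper that folds the selector, the interleave and the batching into a single loop pulling num_images_per_class items from each sampled class.

```python
import itertools
import random


def balanced_batches(per_class_streams, num_classes_per_batch, num_images_per_class, seed=None):
    """Yield balanced batches forever.

    Each batch holds num_images_per_class consecutive elements from each of
    num_classes_per_batch distinct, randomly chosen classes: the same layout as
    selector + choose_from_datasets + batch, written as a single loop.
    """
    rng = random.Random(seed)
    num_classes = len(per_class_streams)
    while True:
        sampled = rng.sample(range(num_classes), num_classes_per_batch)
        batch = []
        for c in sampled:
            # islice advances the shared per-class iterator, so the next pick resumes there
            batch.extend(itertools.islice(per_class_streams[c], num_images_per_class))
        yield batch


# Toy per-class streams: class c yields (c, 0), (c, 1), (c, 2), ... forever.
streams = [zip(itertools.repeat(c), itertools.count()) for c in range(10)]
batches = balanced_batches(streams, num_classes_per_batch=4, num_images_per_class=3, seed=0)
first = next(batches)   # 4 classes x 3 images = 12 (label, index) pairs
second = next(batches)
```

In plain Python the single loop is simpler; in the tf.data version the class choice is expressed as a separate dataset of indices because that is the input choose_from_datasets (or DirectedInterleaveDataset) expects.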
