benjaminplanche

Reputation: 15119

Tensorflow Dataset - How to build batches given a generator outputting X inputs for 1 label?

Short Version

Given a generator sampling e.g. 3 inputs and 1 label, how can I define my Tensorflow Dataset pipeline to obtain batches of K * 3 inputs and K * 1 labels?


Longer Version

Context

I am using a Triplet network, and want to adapt my current input pipeline to use Tensorflow Dataset.

In my case, a batch consists of N inputs (e.g. images) and N // 3 labels (supposing N % 3 == 0), with each label applied to 3 consecutive inputs, e.g.

labels = [compute_label(inputs[3*i], inputs[3*i+1], inputs[3*i+2]) for i in range(N // 3)]

with compute_label(*args) a simple function, which can be implemented either with Tensorflow operations or basic Python.

To make things a bit more complicated, the input elements must be sampled 3 by 3 (e.g. we want inputs[3*i] to be similar to inputs[3*i+1] and dissimilar to inputs[3*i+2]):

for i in range(N // 3):
    inputs[3*i], inputs[3*i+1], inputs[3*i+2] = sample_triplet(i)

Question

Reformulating the shorter question to my particular case:

Given these two functions sample_triplet() and compute_label(), how can I build my input pipeline using Tensorflow Dataset, to build batches with N inputs and N // 3 labels?

I tried several combinations of tf.data.Dataset.from_generator() and tf.data.Dataset.flat_map() but couldn't find a way to both flatten the batch inputs from N // 3 triplets to N samples, and output only N // 3 batch labels.

A solution I found was to "cheat" by computing my labels inside tf.data.Dataset.from_generator() and tiling each label 3 times, to be able to use tf.data.Dataset.flat_map() over the triplet inputs + labels. As a batch-postprocessing step, I am then "squeezing" the N duplicated labels back to N // 3.

Example with Current Solution

import tensorflow as tf
import numpy as np

def sample_triplet():
    # Sampling our elements, here as [class, random_val] elements:
    anchor_class = puller_class = pusher_class = np.random.randint(0, 10)
    while pusher_class == anchor_class:
        # we want the pusher to be of a different class
        pusher_class = np.random.randint(0, 10) 

    anchor = np.array([anchor_class, np.random.randint(0, 5)])
    puller = np.array([puller_class, np.random.randint(0, 5)])
    pusher = np.array([pusher_class, np.random.randint(0, 5)])

    # Stacking the triplets, to then flat_map as a batch:
    triplet_inputs = np.stack((anchor, puller, pusher), axis=0)
    # Returning also the classes to compute the label afterwards:
    triplet_classes = np.stack((anchor_class, puller_class, pusher_class), axis=0)
    return triplet_inputs, triplet_classes

def compute_labels(triplet_classes):
    # Computing the label, e.g. distance between the anchor and pusher classes:
    label = np.abs(triplet_classes[0] - triplet_classes[2])
    return label

def triplet_generator():
    while True:
        triplet = sample_triplet()

        # Current solution: computing the label here too, 
        # stacking it 3 times so that flat_map works,
        # then afterwards removing the duplicates:
        triplet_inputs = triplet[0]
        triplet_label = compute_labels(triplet[1])
        yield (triplet_inputs,
               np.stack((triplet_label, triplet_label, triplet_label), axis=0))

def squeeze_triplet_labels(*batch):
    # Removing the duplicate labels,
    # going from a batch of (N inputs, N labels) to (N inputs, N // 3 labels)
    squeezed_labels = batch[-1][::3]
    new_batch = (*batch[:-1], squeezed_labels)
    return new_batch


batch_size = 30
assert(batch_size % 3 == 0)
sess = tf.InteractiveSession()
train_dataset = (tf.data.Dataset
                 .from_generator(triplet_generator, (tf.int32, tf.float32), ([3, 2], [3]))
                 .flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x))
                 .batch(batch_size))

next_training_batch = train_dataset.make_one_shot_iterator().get_next()
next_proper_training_batch = squeeze_triplet_labels(*next_training_batch)
batch = sess.run(next_proper_training_batch)
print("inputs shape: {} ; label shape: {}".format(batch[0].shape, batch[1].shape))
# >> inputs shape: (30, 2) ; label shape: (10,)

Upvotes: 2

Views: 708

Answers (1)

David Parks

Reputation: 32071

One simple solution could be to create 2 Dataset objects, one for labels, one for data, then batch the data by groups of 3 and use tf.data.Dataset.zip to pair the two datasets back together, producing the results you want.
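A minimal sketch of this two-dataset idea, assuming TensorFlow 2.x with eager execution and data that is already sampled and stored flat in triplet order. The arrays inputs and labels and the value of K are hypothetical placeholders, and the pairing is done here with tf.data.Dataset.zip:

```python
import numpy as np
import tensorflow as tf

K = 4  # number of triplets per batch (hypothetical value)

# Hypothetical pre-sampled data: 8 triplets of 2-D inputs, stored flat in
# triplet order, with one label per triplet.
inputs = np.arange(8 * 3 * 2, dtype=np.int32).reshape(8 * 3, 2)
labels = np.arange(8, dtype=np.float32)

# Group the flat inputs 3 by 3, so that one element is one triplet of shape [3, 2].
inputs_ds = tf.data.Dataset.from_tensor_slices(inputs).batch(3, drop_remainder=True)
labels_ds = tf.data.Dataset.from_tensor_slices(labels)

dataset = (tf.data.Dataset.zip((inputs_ds, labels_ds))
           .batch(K)                                         # ([K, 3, 2], [K])
           .map(lambda x, y: (tf.reshape(x, [-1, 2]), y)))   # ([K * 3, 2], [K])

batch_inputs, batch_labels = next(iter(dataset))
print(batch_inputs.shape, batch_labels.shape)
```

This only works when both datasets iterate in a fixed, matching order. Building them from two independent random generators would pair labels with the wrong triplets.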

If that's not easy to do then you could try the following process of mapping one element onto multiple. You would have to create a batch of 3 elements (with 3 labels), then split that in a map function into 3 sets of the data each against one of the labels you received. The recipe for doing that is in the following SO question, though it's a little more involved than the first suggestion:

In Tensorflow's Dataset API how do you map one element into multiple elements?
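For randomly sampled triplets, a variant of the "batch by groups of 3" idea may be simpler than either suggestion: keep each triplet as a single dataset element with its one label, batch N // 3 of them, then reshape the inputs. This is a sketch rather than the method described above. It assumes TensorFlow 2.4 or later (for output_signature) with eager execution, and sample_triplet below is a simplified stand-in for the function in the question:

```python
import numpy as np
import tensorflow as tf

def sample_triplet():
    # Simplified stand-in: anchor and puller share a class, the pusher differs.
    anchor_class = np.random.randint(0, 10)
    pusher_class = (anchor_class + np.random.randint(1, 10)) % 10
    classes = np.array([anchor_class, anchor_class, pusher_class], dtype=np.int32)
    values = np.random.randint(0, 5, size=3).astype(np.int32)
    # Inputs as rows of [class, random_val], shape [3, 2]; classes of shape [3].
    return np.stack((classes, values), axis=1), classes

def triplet_generator():
    while True:
        triplet_inputs, triplet_classes = sample_triplet()
        # One label per triplet, with no tiling needed.
        label = np.abs(triplet_classes[0] - triplet_classes[2])
        yield triplet_inputs, np.float32(label)

batch_size = 30
assert batch_size % 3 == 0

dataset = (tf.data.Dataset
           .from_generator(
               triplet_generator,
               output_signature=(tf.TensorSpec(shape=(3, 2), dtype=tf.int32),
                                 tf.TensorSpec(shape=(), dtype=tf.float32)))
           .batch(batch_size // 3)                           # ([N // 3, 3, 2], [N // 3])
           .map(lambda x, y: (tf.reshape(x, [-1, 2]), y)))   # ([N, 2], [N // 3])

inputs, labels = next(iter(dataset))
print("inputs shape: {} ; label shape: {}".format(inputs.shape, labels.shape))
```

Because the reshape only merges the two leading dimensions, the 3 inputs of each triplet stay consecutive in the batch. Under TensorFlow 1.x, the same pipeline would presumably need the one-shot iterator and session from the question instead of next(iter(dataset)).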

Upvotes: 2
