aiven

Reputation: 4313

Build a TensorFlow dataset iterator that produces batches with a special structure

As I mentioned in the title, I need batches with a special structure:

1111
5555
2222

Each digit represents a feature vector. So there are N=4 vectors for each of the classes {1, 2, 5} (M=3), and the batch size is N×M=12.
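To make the layout concrete, here is a plain-Python sketch with toy stand-in data (the class IDs and two-element feature vectors are hypothetical) that assembles one batch in this structure:

```python
N = 4  # vectors per class in a batch
M = 3  # classes per batch

# Toy stand-in data: each class maps to a list of (hypothetical) feature vectors.
features = {c: [[float(c)] * 2 for _ in range(10)] for c in (1, 5, 2)}

# Stack N vectors of each class back to back: the 1111 / 5555 / 2222 layout.
batch = [v for c in (1, 5, 2) for v in features[c][:N]]
print(len(batch))  # → 12, i.e. N * M feature vectors
```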

To accomplish this task I'm using the TensorFlow Dataset API and TFRecords.

My concern is that I have hundreds (and maybe thousands in the future) of classes, and storing an iterator for each class doesn't look good (from a memory and performance perspective).

Is there a better way?

Upvotes: 1

Views: 1359

Answers (1)

javidcf

Reputation: 59701

If you have the list of files ordered by class, you can interleave the datasets:

import tensorflow as tf

N = 4
record_files = ['class1.tfrecord', 'class5.tfrecord', 'class2.tfrecord']
M = len(record_files)

dataset = tf.data.Dataset.from_tensor_slices(record_files)
# Consider tf.contrib.data.parallel_interleave for parallelization
dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=M, block_length=N)
# Consider passing num_parallel_calls or using tf.contrib.data.map_and_batch for performance
dataset = dataset.map(parse_function)  # parse_function is your own record-parsing function
dataset = dataset.batch(N * M)

EDIT:

If you also need shuffling, you can add it in the interleaving step:

import tensorflow as tf

N = 4
record_files = ['class1.tfrecord', 'class5.tfrecord', 'class2.tfrecord']
M = len(record_files)
SHUFFLE_BUFFER_SIZE = 1000

dataset = tf.data.Dataset.from_tensor_slices(record_files)
dataset = dataset.interleave(
    lambda record_file: tf.data.TFRecordDataset(record_file).shuffle(SHUFFLE_BUFFER_SIZE),
    cycle_length=M, block_length=N)
dataset = dataset.map(parse_function)  # parse_function is your own record-parsing function
dataset = dataset.batch(N * M)

NOTE: Both interleave and batch will produce "partial" outputs if there are no more remaining elements (see the docs). So you have to take special care if it is important to you that every batch has the same shape and structure. For batching you can use tf.contrib.data.batch_and_drop_remainder, but as far as I know there is no similar alternative for interleaving, so you would either have to make sure that all of your files have the same number of examples or add repeat to the interleaving transformation.
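To illustrate how partial blocks arise, here is a plain-Python sketch (not TF code; the helper name and toy sources are made up) that mimics interleave's cycle_length/block_length behaviour when the sources have unequal lengths:

```python
def interleave(sources, cycle_length, block_length):
    """Sketch of tf.data-style interleave semantics: take up to block_length
    items from each of cycle_length sources in turn, dropping exhausted ones."""
    iters = [iter(s) for s in sources[:cycle_length]]
    out = []
    while iters:
        for it in list(iters):
            block = []
            for _ in range(block_length):
                try:
                    block.append(next(it))
                except StopIteration:
                    iters.remove(it)
                    break
            out.extend(block)
    return out

# Sources with unequal lengths produce a "partial" block at the end.
result = interleave([['a'] * 5, ['b'] * 4], cycle_length=2, block_length=4)
print(''.join(result))  # → aaaabbbba
```

The trailing lone `a` is exactly the kind of partial output that breaks a fixed N×M batch structure, which is why equal file sizes (or a repeat) matter.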

EDIT 2:

I got a proof of concept of something like what I think you want:

import tensorflow as tf

NUM_EXAMPLES = 12
NUM_CLASSES = 9
records = [[str(i)] * NUM_EXAMPLES for i in range(NUM_CLASSES)]
M = 3
N = 4

dataset = tf.data.Dataset.from_tensor_slices(records)
dataset = dataset.interleave(tf.data.Dataset.from_tensor_slices,
                             cycle_length=NUM_CLASSES, block_length=N)
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(NUM_CLASSES * N))
dataset = dataset.flat_map(
    lambda data: tf.data.Dataset.from_tensor_slices(
        tf.split(tf.random_shuffle(
            tf.reshape(data, (NUM_CLASSES, N))), NUM_CLASSES // M)))
dataset = dataset.map(lambda data: tf.reshape(data, (M * N,)))
batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    while True:
        try:
            b = sess.run(batch)
            print(b''.join(b).decode())
        except tf.errors.OutOfRangeError:
            break

Output:

888866663333
555544447777
222200001111
222288887777
666655553333
000044441111
888822225555
666600004444
777733331111

The equivalent with record files would be something like this (assuming records are one-dimensional vectors):

import tensorflow as tf

NUM_CLASSES = 9
record_files = ['class{}.tfrecord'.format(i) for i in range(NUM_CLASSES)]
M = 3
N = 4
SHUFFLE_BUFFER_SIZE = 1000

dataset = tf.data.Dataset.from_tensor_slices(record_files)
dataset = dataset.interleave(
    lambda file_name: tf.data.TFRecordDataset(file_name).shuffle(SHUFFLE_BUFFER_SIZE),
    cycle_length=NUM_CLASSES, block_length=N)
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(NUM_CLASSES * N))
dataset = dataset.flat_map(
    lambda data: tf.data.Dataset.from_tensor_slices(
        tf.split(tf.random_shuffle(
            tf.reshape(data, (NUM_CLASSES, N, -1))), NUM_CLASSES // M)))
dataset = dataset.map(lambda data: tf.reshape(data, (M * N, -1)))

This works by reading N elements of every class each time, then shuffling and splitting the resulting block. It assumes that the number of classes is divisible by M and that all the files have the same number of records.
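The reshape-shuffle-split step can be sketched in plain Python (toy class IDs standing in for parsed records) to show why each output batch ends up with exactly N examples from each of M classes:

```python
import random

NUM_CLASSES = 9
N = 4  # examples per class
M = 3  # classes per batch

# One interleaved block: N consecutive examples of every class, as a flat list.
block = [c for c in range(NUM_CLASSES) for _ in range(N)]

# Reshape to (NUM_CLASSES, N) rows, shuffle the class rows,
# then split them into groups of M rows and flatten each group into a batch.
rows = [block[i * N:(i + 1) * N] for i in range(NUM_CLASSES)]
random.shuffle(rows)
batches = [sum(rows[i * M:(i + 1) * M], []) for i in range(NUM_CLASSES // M)]

for b in batches:
    print(b)  # each batch holds M * N elements: N examples from each of M classes
```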

Upvotes: 3
