Florentin Hennecker

Reputation: 2164

tf.data: Parallelize loading step

I have a data input pipeline that has:

I've been trying to fit this into a tf.data pipeline, and I'm stuck on running the preprocessing for multiple datapoints in parallel. So far I've tried this:

Am I missing anything here? Am I forced to either modify my preprocessing so that it can run in a graph or is there a way to multiprocess it?

Our previous way of doing this was keras.Sequence, which worked well, but there are just too many people pushing the upgrade to the tf.data API. (Hell, even trying keras.Sequence with tf 2.2 yields WARNING:tensorflow:multiprocessing can interact badly with TensorFlow, causing nondeterministic deadlocks. For high performance data pipelines tf.data is recommended.)

Note: I'm using tf 2.2rc3

Upvotes: 11

Views: 4768

Answers (2)

A. Cordier

Reputation: 56

I came across the same problem and found a (relatively) easy solution.

It turns out that the proper way to do this is indeed to first create a tf.data.Dataset object with the from_generator(gen) method, before applying your custom Python processing function (wrapped in a py_function) with the map method. As you mentioned, there is a trick to avoid serialization / deserialization of the input.

The trick is to use a generator that only yields the indexes of your training set. Each yielded index is passed to the wrapped py_function, which can in turn evaluate your original dataset at that index. You can then process your datapoint and return the processed data to the rest of your tf.data pipeline.

def func(i):
    i = i.numpy() # decoding from the EagerTensor object
    x, y = processing_function(training_set[i])
    return x, y # numpy arrays of dtypes uint8, float32

z = list(range(len(training_set))) # the indexes to feed the generator

# int64 indexes avoid silent overflow on training sets larger than 255 items
dataset = tf.data.Dataset.from_generator(lambda: z, tf.int64)

dataset = dataset.map(lambda i: tf.py_function(func=func, inp=[i],
                                               Tout=[tf.uint8, tf.float32]),
                      num_parallel_calls=12)

dataset = dataset.batch(1)

Note that in practice, depending on the model you train your dataset on, you will probably need to apply another map to your dataset after the batch:

def _fixup_shape(x, y):
    x.set_shape([None, None, None, nb_channels])
    y.set_shape([None, nb_classes])
    return x, y
dataset = dataset.map(_fixup_shape)

This is a known issue: the from_generator method seems unable to infer the output shape properly in some cases, so you need to pass the expected output shape explicitly.
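To see the fix-up in action, here is a minimal self-contained sketch of the same pipeline, with dummy arrays standing in for the real training_set and processing_function (both hypothetical stand-ins), showing how set_shape restores the static dimensions that py_function erases:

```python
import numpy as np
import tensorflow as tf

nb_channels, nb_classes = 3, 10

def func(i):
    # dummy stand-ins for the real processing_function / training_set
    x = np.zeros((8, 8, nb_channels), np.uint8)
    y = np.zeros((nb_classes,), np.float32)
    return x, y

dataset = tf.data.Dataset.from_generator(lambda: range(4), tf.int64)
dataset = dataset.map(lambda i: tf.py_function(func, [i], [tf.uint8, tf.float32]))
dataset = dataset.batch(2)
# at this point dataset.element_spec reports unknown shapes

def _fixup_shape(x, y):
    x.set_shape([None, None, None, nb_channels])
    y.set_shape([None, nb_classes])
    return x, y

dataset = dataset.map(_fixup_shape)
# now element_spec carries the static channel / class dimensions again
```

Without the second map, downstream Keras layers that need a known channel count will refuse to build.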

Upvotes: 4

StefanMK

Reputation: 1303

You can try adding batch() before map() in your input pipeline.

It is usually meant to reduce the overhead of the map function call for small map functions, see here: https://www.tensorflow.org/guide/data_performance#vectorizing_mapping

However, you can also use it to get a whole batch of inputs into your py_function map and use Python multiprocessing there to speed things up.

This way you can get around the GIL limitation, which makes num_parallel_calls in tf.data.map() useless for py_function map functions.
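A rough sketch of that idea, assuming a picklable per-item preprocess function (a hypothetical stand-in for your real preprocessing) that a worker pool can dispatch to:

```python
import multiprocessing
import numpy as np
import tensorflow as tf

def preprocess(i):
    # hypothetical per-item CPU-heavy work; must be a top-level
    # function so the pool can pickle it
    return np.full((4,), i, dtype=np.float32)

pool = multiprocessing.Pool(processes=4)

def batched_func(indices):
    # one py_function call receives a whole batch of indexes; the pool
    # does the heavy work in child processes, sidestepping the GIL
    out = pool.map(preprocess, indices.numpy())
    return np.stack(out)

dataset = tf.data.Dataset.range(16)
dataset = dataset.batch(8)  # batch *before* map
dataset = dataset.map(lambda idx: tf.py_function(batched_func, [idx], tf.float32))
```

Each element of the resulting dataset is then one preprocessed batch, so no further batch() call is needed downstream.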

Upvotes: 2
