Reputation: 2164
I have a data input pipeline that has:
- input datapoints of types that can't be converted to a tf.Tensor (dicts and whatnot)
I've been trying to fit this into a tf.data pipeline, and I'm stuck on running the preprocessing for multiple datapoints in parallel. So far I've tried this:
- use Dataset.from_generator(gen) and do the preprocessing in the generator; this works, but it processes each datapoint sequentially, no matter what arrangement of prefetch and fake map calls I patch on it. Is it impossible to prefetch in parallel?
- wrap the preprocessing in a tf.py_function so I could map it in parallel over my Dataset, but the py_function would be handed over to the (single-process) Python interpreter, so I'd be stuck with the Python GIL, which would not help me much
- look for tricks with interleave, but I haven't found any that avoids the issues of the first two ideas.
Am I missing anything here? Am I forced to either modify my preprocessing so that it can run in a graph, or is there a way to multiprocess it?
Our previous way of doing this was using keras.Sequence, which worked well, but there are just too many people pushing the upgrade to the tf.data API. (Hell, even trying keras.Sequence with tf 2.2 yields WARNING:tensorflow:multiprocessing can interact badly with TensorFlow, causing nondeterministic deadlocks. For high performance data pipelines tf.data is recommended.)
Note: I'm using tf 2.2rc3
Upvotes: 11
Views: 4768
Reputation: 56
I came across the same problem and found a (relatively) easy solution.
It turns out that the proper way to do so is indeed to first create a tf.data.Dataset object using the from_generator(gen) method, and then apply your custom Python processing function (wrapped in a py_function) with the map method. There is also a trick to avoid the serialization/deserialization of the input that you mentioned.
The trick is to use a generator that only generates the indexes of your training set. Each index is passed to the wrapped py_function, which in turn looks up your original dataset at that index. You can then process your datapoint and return the processed data to the rest of your tf.data pipeline.
import tensorflow as tf

# training_set (indexable) and processing_function (pure Python) are your own objects
def func(i):
    i = i.numpy()  # decoding from the EagerTensor object
    x, y = processing_function(training_set[i])
    return x, y  # numpy arrays of types uint8, float32

z = list(range(len(training_set)))  # the index generator
# int64 rather than uint8 for the indexes: uint8 cannot hold indexes above 255
dataset = tf.data.Dataset.from_generator(lambda: z, tf.int64)
dataset = dataset.map(lambda i: tf.py_function(func=func, inp=[i],
                                               Tout=[tf.uint8, tf.float32]),
                      num_parallel_calls=12)
dataset = dataset.batch(1)
Note that in practice, depending on the model you train on this dataset, you will probably need to apply another map to your dataset after the batch:
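# nb_channels and nb_classes are the sizes your model expects
def _fixup_shape(x, y):
    x.set_shape([None, None, None, nb_channels])  # e.g. (batch, height, width, channels) for images
    y.set_shape([None, nb_classes])  # (batch, classes)
    return x, y

dataset = dataset.map(_fixup_shape)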
This is a known issue: the tensors returned by tf.py_function carry no static shape information (and from_generator cannot always infer the shape properly either), so you need to set the expected output shape explicitly. For more information:
Upvotes: 4
Reputation: 1303
You can try to add batch() before map() in your input pipeline.
It is usually meant to reduce the overhead of calling the map function when that function is small; see here: https://www.tensorflow.org/guide/data_performance#vectorizing_mapping
However, you can also use it to pass a whole batch of inputs to your map py_function and use Python multiprocessing inside it to speed things up.
This way you can get around the GIL limitation, which makes num_parallel_calls in Dataset.map() largely ineffective for CPU-bound py_function map functions.
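A minimal sketch of that idea, assuming the index trick from the other answer (only integer indexes reach the Python side). TRAINING_SET, preprocess and process_batch are hypothetical names, and the dict datapoints and the doubling step merely stand in for real preprocessing. It uses only the standard library so it can be tried without TensorFlow:

```python
import multiprocessing as mp

# Hypothetical stand-ins: datapoints are plain Python dicts and the
# preprocessing is pure Python (the kind of work the GIL serializes).
TRAINING_SET = [{"pixels": [i, i + 1, i + 2], "label": i % 3} for i in range(100)]


def preprocess(index):
    """Preprocess one datapoint, looked up by its index in TRAINING_SET."""
    point = TRAINING_SET[index]
    x = [p * 2 for p in point["pixels"]]  # placeholder for the real work
    return x, point["label"]


def process_batch(pool, indices):
    """Body of a batched map function: spread a batch of indexes over
    worker processes, then regroup the results into (xs, ys)."""
    # Only the small integer indexes are pickled and sent to the workers;
    # with the "fork" start method each worker already holds TRAINING_SET.
    results = pool.map(preprocess, [int(i) for i in indices])
    xs = [x for x, _ in results]
    ys = [y for _, y in results]
    return xs, ys


if __name__ == "__main__":
    # "fork" is Unix-only; it lets workers inherit TRAINING_SET for free.
    with mp.get_context("fork").Pool(processes=4) as pool:
        xs, ys = process_batch(pool, [0, 1, 5])
        print(xs, ys)  # prints [[0, 2, 4], [2, 4, 6], [10, 12, 14]] [0, 1, 2]
```

To plug this into a tf.data pipeline under these assumptions, one would batch the index dataset before map and have the function wrapped by tf.py_function call .numpy() on its argument before handing it to process_batch. Creating the pool once and reusing it avoids paying the process start-up cost on every batch; creating it early, before TensorFlow starts its own threads, is probably safer, since forking a process that already runs threads is a likely source of the deadlocks the warning in the question mentions.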
Upvotes: 2