beardybear

Reputation: 159

parallelising tf.data.Dataset.from_generator with TF2.1

There are already two posts about this topic, but they have not been updated for the recent TF 2.1 release...

In brief, I've got a lot of tif images to read and parse with a specific pipeline.

import functools

import tensorflow as tf
import numpy as np

files = # a list of str
labels = # a list of int
n_unique_label = len(np.unique(labels))

# Bind the file/label lists and extra parameters to the generator,
# build the dataset from it, then one-hot encode the labels.
gen = functools.partial(generator, file_list=files, label_list=labels, param1=x1, param2=x2)
dataset = tf.data.Dataset.from_generator(gen, output_types=(tf.float32, tf.int32))
dataset = dataset.map(lambda b, c: (b, tf.one_hot(c, depth=n_unique_label)))
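
where the generator boils down to something like this (a simplified sketch; the extra parameters are placeholders):

import numpy as np
from PIL import Image

def generator(file_list, label_list, param1=1, param2=2):
    # Yield one (image, label) pair per file, matching
    # output_types=(tf.float32, tf.int32) above.
    for file_path, label in zip(file_list, label_list):
        image = np.array(Image.open(file_path)).astype(np.float32) / 255.0
        yield image, label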

This processing works well. Nevertheless, I need to parallelize the file-parsing part, so I tried the following solution:

files = # a list of str
files = tf.data.Dataset.from_tensor_slices(files)

def wrapper(file_path):
    # Bind the extra parameters, then let tf.py_function call the parser
    # on each file path.
    parser = functools.partial(tif_parser, param1=x1, param2=x2)
    return tf.py_function(parser, inp=[file_path], Tout=[tf.float32])

dataset = files.map(wrapper, num_parallel_calls=2)

The difference is that I parse one file at a time here with the parser function. However, it does not work:

  File "loader.py", line 643, in tif_parser
    image = numpy.array(Image.open(file_path)).astype(float)

  File "python3.7/site-packages/PIL/Image.py", line 2815, in open
    fp = io.BytesIO(fp.read())

AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute 'read'


     [[{{node EagerPyFunc}}]] [Op:IteratorGetNextSync]

As far as I understand, the tif_parser function does not receive a plain Python string but an EagerTensor. For now, this function is fairly simple:

import numpy
from PIL import Image

def tif_parser(file_path, param1=1, param2=2):
    # Read the image and rescale pixel values to [0, 1].
    image = numpy.array(Image.open(file_path)).astype(float)
    image /= 255.0

    return image
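
From what I gather, the path would have to be pulled out of the tensor before PIL can open it, something along these lines (untested sketch):

def tif_parser(file_path, param1=1, param2=2):
    # When called through tf.py_function, file_path is an EagerTensor holding
    # the path as bytes; .numpy() extracts the bytes and .decode() gives a str.
    if isinstance(file_path, tf.Tensor):
        file_path = file_path.numpy().decode("utf-8")
    image = numpy.array(Image.open(file_path)).astype(float)
    image /= 255.0
    return image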

Upvotes: 0

Views: 1505

Answers (1)

beardybear

Reputation: 159

Here is how I proceeded:

import functools

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices((files, labels))

def wrapper(file_path, label):
    # Each call builds a small dataset that parses a single (file, label) pair;
    # the path and label are handed to tif_parser through args.
    parser = functools.partial(tif_parser, param1=x1, param2=x2)
    return tf.data.Dataset.from_generator(parser, (tf.float32, tf.int32), args=(file_path, label))

# interleave applies the wrapper to several input elements at once and
# interleaves the resulting per-file datasets.
dataset = dataset.interleave(wrapper, cycle_length=tf.data.experimental.AUTOTUNE)

# The labels are converted to one-hot vectors; this could also be done inside tif_parser
dataset = dataset.map(lambda i, l: (i, tf.one_hot(l, depth=unique_label_count)))

dataset = dataset.shuffle(buffer_size=file_count, reshuffle_each_iteration=True)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=False)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

Concretely, I generate a dataset every time the parser is called. The parser is run cycle_length times at each call, meaning that cycle_length images are read at once. This suits my specific case, because I cannot load all the images into memory. I am unsure whether prefetch is used correctly here.
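
For this to work, tif_parser has to be turned into a generator that takes the path and the label (from_generator passes args as NumPy values, so the path arrives as bytes) and yields a single (image, label) pair. Roughly, as a simplified sketch based on the parser from the question:

import numpy as np
from PIL import Image

def tif_parser(file_path, label, param1=1, param2=2):
    # args from from_generator arrive as NumPy values: the path as bytes,
    # the label as an integer.
    image = np.array(Image.open(file_path.decode("utf-8"))).astype(np.float32)
    image /= 255.0
    # Yield one (image, label) pair, matching (tf.float32, tf.int32).
    yield image, label

The resulting batched dataset can then be iterated directly or passed to model.fit.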

Upvotes: 3
