Piotr Czapla

Is there a way to use tf.data.Dataset inside of another Dataset in Tensorflow?

I'm doing image segmentation. Each training sample has multiple mask images, and I'm trying to write an input_fn that merges all of the masks into a single tensor per training sample. I was planning to use two Datasets: one that iterates over the sample folders, and a nested one that reads all of a sample's masks as one large batch and then merges them into a single tensor.

I'm getting an error when the nested make_one_shot_iterator() is called. I know this approach is a bit of a stretch, and most likely Datasets weren't designed for such usage. But then how should I approach this problem so that I avoid using tf.py_func?

Here is a simplified version of the dataset:

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset
                .list_files(sample_path + "/masks/*.png")
                .map(tf.read_file)
                .map(lambda x: tf.image.decode_image(x, channels=1))
                .batch(1024))  # maximum number of objects per sample
    # This nested iterator is where the error is raised.
    masks = masks_ds.make_one_shot_iterator().get_next()

    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.gfile.Glob("../input/stage1_train/*"))
ds = ds.map(read_sample)
# ...
sample = ds.make_one_shot_iterator().get_next()
# ...

Upvotes: 4

Answers (1)

mrry

If the nested dataset has only a single element, you can use tf.contrib.data.get_single_element() on the nested dataset instead of creating an iterator:

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset.list_files(sample_path+"/masks/*.png")
                .map(tf.read_file)
                .map(lambda x: tf.image.decode_image(x, channels=1))
                .batch(1024)) # maximum number of objects
    masks = tf.contrib.data.get_single_element(masks_ds)
    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.gfile.Glob("../input/stage1_train/*"))
ds = ds.map(read_sample)
sample = ds.make_one_shot_iterator().get_next()
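
For completeness, here is a minimal sketch of how the resulting pipeline might be consumed in TF 1.x graph mode; the loop body is illustrative:

with tf.Session() as sess:
    try:
        while True:
            merged_mask = sess.run(sample)  # one merged mask per training sample
            # ... use merged_mask in the training step
    except tf.errors.OutOfRangeError:
        pass  # the one-shot iterator has been exhausted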

In addition, you can use the tf.data.Dataset.flat_map(), tf.data.Dataset.interleave(), or tf.contrib.data.parallel_interleave() transformations to perform a nested Dataset computation inside a function and flatten the result into a single Dataset. For example, to get all of the samples in a single Dataset:

def read_all_samples(sample_path):
    return (tf.data.Dataset.list_files(sample_path+"/masks/*.png")
            .map(tf.read_file)
            .map(lambda x: tf.image.decode_image(x, channels=1))
            .batch(1024)) # maximum number of objects

ds = tf.data.Dataset.from_tensor_slices(tf.gfile.Glob("../input/stage1_train/*"))
ds = ds.flat_map(read_all_samples)
sample = ds.make_one_shot_iterator().get_next()
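
If reading the mask files is I/O bound, the same nested computation can be parallelized by combining tf.data.Dataset.apply() with tf.contrib.data.parallel_interleave(). A minimal sketch, reusing read_all_samples from above; the cycle_length value is illustrative:

ds = tf.data.Dataset.from_tensor_slices(tf.gfile.Glob("../input/stage1_train/*"))
ds = ds.apply(tf.contrib.data.parallel_interleave(
    read_all_samples,
    cycle_length=4))  # how many sample directories to read from in parallel
sample = ds.make_one_shot_iterator().get_next()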

Upvotes: 5
