I'm doing segmentation. Each training sample has multiple images with segmentation masks. I'm trying to write an input_fn that merges all mask images into one for each training sample.
I was planning on using two Datasets: one that iterates over sample folders and another that reads all masks as one large batch and then merges them into one tensor.
I'm getting an error when the nested make_one_shot_iterator is called. I know that this approach is a bit of a stretch and that datasets most likely weren't designed for such usage. But then how should I approach this problem so that I avoid using tf.py_func?
Here is a simplified version of the dataset:
import tensorflow as tf  # TensorFlow 1.x API

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset
                .list_files(sample_path + "/masks/*.png")
                .map(tf.read_file)
                .map(lambda x: tf.image.decode_image(x, channels=1))
                .batch(1024))  # maximum number of objects
    # Creating an iterator inside a map() function is what raises the error.
    masks = masks_ds.make_one_shot_iterator().get_next()
    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.gfile.Glob("../input/stage1_train/*"))
ds = ds.map(read_sample)
# ...
sample = ds.make_one_shot_iterator().get_next()
# ...
If the nested dataset has only a single element, you can use tf.contrib.data.get_single_element() on the nested dataset instead of creating an iterator:
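import tensorflow as tf  # TensorFlow 1.x API

def read_sample(sample_path):
    masks_ds = (tf.data.Dataset.list_files(sample_path + "/masks/*.png")
                .map(tf.read_file)
                .map(lambda x: tf.image.decode_image(x, channels=1))
                .batch(1024))  # maximum number of objects
    # With at most 1024 masks per sample, the batched dataset has a single
    # element, so it can be read directly without a nested iterator.
    masks = tf.contrib.data.get_single_element(masks_ds)
    # Merge the per-object masks into one mask for the sample.
    return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.gfile.Glob("../input/stage1_train/*"))
ds = ds.map(read_sample)
sample = ds.make_one_shot_iterator().get_next()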
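The merge step in read_sample is tf.reduce_max(masks, axis=0): for each pixel it keeps the largest value across all object masks, so a pixel that is foreground in any mask stays foreground in the result. The NumPy sketch below reproduces that reduction outside TensorFlow, assuming each mask decodes to a uint8 array of shape (H, W, 1) holding 0 for background and 255 for the object. The helper name merge_masks and the toy 4x4 masks are hypothetical, for illustration only.

```python
import numpy as np

def merge_masks(masks):
    """Union of binary masks stacked along axis 0, shape (num_objects, H, W, 1).

    Hypothetical helper: taking the max over the object axis mirrors
    tf.reduce_max(masks, axis=0), so a pixel set in any mask stays set.
    """
    return masks.max(axis=0)

# Two toy object masks (assumed 0 = background, 255 = object).
a = np.zeros((4, 4, 1), dtype=np.uint8)
b = np.zeros((4, 4, 1), dtype=np.uint8)
a[0:2, 0:2] = 255  # object 1 in the top-left corner
b[2:4, 2:4] = 255  # object 2 in the bottom-right corner

merged = merge_masks(np.stack([a, b]))  # stacked shape: (2, 4, 4, 1)
print(merged.shape, int((merged == 255).sum()))
```

Because batch(1024) packs every mask of a sample folder into one batch, the nested dataset has a single element, which is what get_single_element() requires; a folder with more than 1024 masks would produce a second batch.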
In addition, you can use the tf.data.Dataset.flat_map(), tf.data.Dataset.interleave(), or tf.contrib.data.parallel_interleave() transformations to perform a nested Dataset computation inside a function and flatten the result into a single Dataset. For example, to get all of the samples in a single Dataset:
def read_all_samples(sample_path):
    # Return a Dataset rather than a tensor; flat_map() flattens the
    # per-sample datasets into one Dataset of mask batches.
    return (tf.data.Dataset.list_files(sample_path + "/masks/*.png")
            .map(tf.read_file)
            .map(lambda x: tf.image.decode_image(x, channels=1))
            .batch(1024))  # maximum number of objects

ds = tf.data.Dataset.from_tensor_slices(tf.gfile.Glob("../input/stage1_train/*"))
ds = ds.flat_map(read_all_samples)
sample = ds.make_one_shot_iterator().get_next()
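The difference between flat_map() and interleave() is the order in which elements of the nested datasets come out. The sketch below is a pure-Python model of that ordering, written to follow the cycle_length/block_length behaviour described in the tf.data documentation; it is not TensorFlow's implementation, it does not model parallel_interleave(), and the function names are hypothetical. Plain lists stand in for the per-sample datasets.

```python
def interleave(inputs, fn, cycle_length, block_length=1):
    """Model of the output order of Dataset.interleave (illustration only).

    Up to cycle_length nested iterables are open at once; block_length
    items are taken from each in turn. An exhausted slot is refilled
    lazily with fn(next input) the next time the cycle reaches it.
    """
    source = iter(inputs)
    slots = [None] * cycle_length  # currently open nested iterators
    num_open = 0
    end_of_input = False
    cycle_index = 0
    block_index = 0
    while not end_of_input or num_open > 0:
        it = slots[cycle_index]
        if it is not None:
            try:
                value = next(it)
            except StopIteration:
                # Nested element exhausted: close the slot, move to the next one.
                slots[cycle_index] = None
                num_open -= 1
                cycle_index = (cycle_index + 1) % cycle_length
                block_index = 0
                continue
            yield value
            block_index += 1
            if block_index == block_length:
                cycle_index = (cycle_index + 1) % cycle_length
                block_index = 0
        elif not end_of_input:
            # Empty slot: open the nested iterable for the next input element.
            try:
                slots[cycle_index] = iter(fn(next(source)))
                num_open += 1
            except StopIteration:
                end_of_input = True
        else:
            # No more inputs to open here; skip to the next slot.
            cycle_index = (cycle_index + 1) % cycle_length
            block_index = 0


def flat_map(inputs, fn):
    """Model of Dataset.flat_map order: interleave with one cycle slot."""
    return interleave(inputs, fn, cycle_length=1)


print(list(flat_map([1, 2], lambda x: [x, x * 10])))
print(list(interleave(range(1, 6), lambda x: [x] * 6, cycle_length=2, block_length=4)))
```

In the answer's code this means flat_map() emits every batch of the first sample folder before any batch of the second, while interleave() with cycle_length greater than 1 would mix batches from several folders.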