Reputation: 49
I am trying to load image dataset on to GPU using tf.data API since they offer optimized perfomance. But unfortunately using tf.data.Dataset.map() function doesnt return a dataset compatible with model.fit() or model.fit_generator(). Assume the directory tree is same as that required for keras ImageDataGenerator.
files = [os.path.join(train_dir, file) for file in os.listdir(train_dir)]
val_files = [os.path.join(val_dir, file) for file in os.listdir(val_dir)]
def get_data_train(file_path: str) -> tuple:
mask_path = tf.strings.regex_replace(file_path, '.jpg$', '.png')
mask_path = tf.strings.regex_replace(mask_path, 'Images', 'Label', replace_global=False)
mask = tf.io.read_file(mask_path)
mask = tf.image.decode_png(mask, channels=1)
mask = tf.image.resize_image_with_pad(mask, target_height=544, target_width=960)
image = tf.io.read_file(file_path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize_image_with_pad(image, target_height=544, target_width=960)
if tf.random.uniform(shape=(), minval=0, maxval=1, dtype=tf.float32) > 0.5:
image = tf.image.flip_left_right(image)
mask = tf.image.flip_left_right(mask)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
mask = tf.image.convert_image_dtype(mask, dtype=tf.float32)
return image / 255., mask[:, :, 0] / 255.
def get_data_validation(file_path: str) -> tuple:
mask_path = tf.strings.regex_replace(file_path, '.jpg$', '.png')
mask_path = tf.strings.regex_replace(mask_path, 'Images', 'Label', replace_global=False)
mask = tf.io.read_file(mask_path)
mask = tf.image.decode_png(mask, channels=1)
mask = tf.image.resize_image_with_pad(mask, target_height=544, target_width=960)
image = tf.io.read_file(file_path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize_image_with_pad(image, target_height=544, target_width=960)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
mask = tf.image.convert_image_dtype(mask, dtype=tf.float32)
return image / 255., mask[:, :, 0] / 255.
def configure_for_performance(dataset: tf.data.Dataset):
dataset = dataset.cache()
dataset = dataset.shuffle(buffer_size=8)
dataset = dataset.batch(args.batch_size)
dataset = dataset.prefetch(buffer_size=8)
return dataset
train_ds = tf.data.Dataset.from_tensor_slices(files)
train_ds = train_ds.map(lambda inputs: tf.py_func(get_data_train, [inputs], Tout=[tf.float32, tf.float32]))
val_ds = tf.data.Dataset.from_tensor_slices(val_files)
val_ds = val_ds.map(lambda inputs: tf.py_func(get_data_validation, [inputs], Tout=[tf.float32, tf.float32]))
val_ds = val_ds.map(lambda x: x.set_shape([None, 544, 960, 3], [None, 544, 960]))
train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)
i get this error when i use model.fit() function
val_ds = val_ds.map(lambda x: x.set_shape([None, 544, 960, 3], [None, 544, 960]))
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1038, in map
return MapDataset(self, map_func)
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2611, in __init__
map_func, "Dataset.map()", input_dataset)
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1860, in __init__
self._function.add_to_graph(ops.get_default_graph())
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 479, in add_to_graph
self._create_definition_if_needed()
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 335, in _create_definition_if_needed
self._create_definition_if_needed_impl()
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 344, in _create_definition_if_needed_impl
self._capture_by_value, self._caller_device)
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 864, in func_graph_from_py_func
outputs = func(*func_graph.inputs)
File "/data2/AIShare/Tools/pytorch-env/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1794, in tf_data_structured_function_wrapper
ret = func(*nested_args)
TypeError: <lambda>() takes 1 positional argument but 2 were given
If i dont set the shape of
val_ds = val_ds.map(lambda x: x.set_shape([None, 544, 960, 3], [None, 544, 960]))
Then model.fit() complaints of tensors of unknown rank. As of my research tf.py_func() results in losing of shape data and hence set_shape is required.
I am trying to use tf.data API to load image files of Cityscapes Dataset.
Thank You
Upvotes: 1
Views: 439
Reputation: 116
tf.py_func will not allow you to be on your GPU, it is explained here on the documentation of tf.py_function tensorflow documentation
Maybe you should write a function for your map like
def fct_for_map(img):
#your code
return my_tensor
and after you try
train_ds = train_ds.map(fct_for_map)
I hope this would help you
Upvotes: 0