Reputation: 502
I need to access image shapes inside an augmentation pipeline. However, when I read them through image.shape[0] and image.shape[1] inside the function passed to map,
the augmentations fail because the tensor dimensions are reported as None.
Related issues: How to access Tensor shape in .map?
I'd appreciate it if anyone could help.
parsed_dataset = tf.data.TFRecordDataset(filenames=train_records_paths).map(parsing_fn)  # Returns [image, label]
augmented_dataset = parsed_dataset.map(augment_pipeline)
augmented_dataset = augmented_dataset.unbatch()

def augment_pipeline(original_image, label):
    """
    Returns:
        5 versions of the original image: 4 corner crops + a central crop, and the respective labels.
    """
    central_crop = lambda image: tf.image.central_crop(image, 0.5)
    corner_crops = lambda image: tf.image.extract_patches(
        images=tf.expand_dims(image, 0),  # Turn the image into a batch of a single sample
        sizes=[1, int(0.5 * image.shape[0]), int(0.5 * image.shape[1]), 1],  # 50% of the image's height and width
        rates=[1, 1, 1, 1],
        strides=[1, int(0.5 * image.shape[0]), int(0.5 * image.shape[1]), 1],
        padding="SAME")
    reshaped_patches = tf.reshape(corner_crops(original_image),
                                  [-1, int(0.5 * original_image.shape[0]), int(0.5 * original_image.shape[1]), 3])
    images = tf.concat([reshaped_patches, tf.expand_dims(central_crop(original_image), axis=0)], axis=0)
    label = tf.reshape(label, [1, 1])
    labels = tf.tile(label, [5, 1])
    return images, labels
Upvotes: 0
Views: 1263
Reputation: 324
Every Dataset object is iterable. A Dataset can be either batched or unbatched, and in both cases you can pull out an element and inspect its shape. Here is how to do it in each case.
Case 1. The Dataset is unbatched.
Method 1. Consuming its elements using iter
it = iter(dataset)
element = next(it)
image, label = element
# element is a tuple; image.shape and label.shape hold the concrete shapes
Method 2. Using take
element = next(iter(dataset.take(1)))  # take(1) returns a Dataset, so it still has to be iterated
image, label = element
# element is a tuple
Case 2. The Dataset is batched. I assume here that the dataset contains (image, label) tuples.
Method 1. Using iter
it = iter(dataset)
batch = next(it)
images, labels = batch
# batch is a tuple; check it with type(batch)
Method 2. Using take
batch = dataset.take(1)
# Note that each element of the dataset is a batch, and each batch contains some number of
# (image, label) tuples
batch = next(iter(batch))
images, labels = batch
# batch is again a tuple
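To tie this back to the None shapes in the question, here is a minimal sketch of the difference between the static shape (what image.shape reports while map traces the function) and the run-time shape. The generator gen and the toy image sizes are made up for illustration and stand in for the TFRecord parsing; the sketch assumes TF 2.x with eager execution.

```python
import numpy as np
import tensorflow as tf


def gen():
    # Hypothetical stand-in for parsed records: two images of different sizes.
    yield np.zeros((4, 6, 3), dtype=np.float32), 0
    yield np.zeros((8, 2, 3), dtype=np.float32), 1


dataset = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
        tf.TensorSpec(shape=[], dtype=tf.int32),
    ),
)

# Static shape: height and width are unknown, which is what map sees as None.
static_shape = dataset.element_spec[0].shape
print(static_shape)

# Iterating eagerly (outside map) gives tensors with concrete shapes.
image, label = next(iter(dataset))
print(image.shape)

# Inside map, tf.shape returns the run-time shape as a tensor.
heights = list(dataset.map(lambda img, lbl: tf.shape(img)[0]).as_numpy_iterator())
print(heights)
```

So iterating is handy for inspecting shapes, but inside the mapped function itself only tf.shape (or an eager wrapper) gives the actual dimensions.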
Upvotes: 0
Reputation: 502
After further research, I got this working by using tf.py_function (as suggested here) together with tf.shape(image)[0] (as suggested here).

def augment_pipeline(original_image, label):
    """
    Returns:
        5 versions of the original image: 4 corner crops + a central crop, and the respective labels.
    """
    height = int(tf.shape(original_image)[0].numpy() * 0.5)  # 50% of the image's height
    width = int(tf.shape(original_image)[1].numpy() * 0.5)  # 50% of the image's width
    central_crop = lambda image: tf.image.central_crop(image, 0.5)
    corner_crops = lambda image: tf.image.extract_patches(
        images=tf.expand_dims(image, 0),  # Turn the image into a batch of a single sample
        sizes=[1, height, width, 1],
        rates=[1, 1, 1, 1],
        strides=[1, height, width, 1],
        padding="SAME")
.
.
.
tf.py_function makes it possible to access NumPy values inside the map function:
parsed_dataset = tf.data.TFRecordDataset(filenames=train_records_paths).map(parsing_fn)  # Returns [image, label]
augmented_dataset = parsed_dataset.map(
    lambda image, label: tf.py_function(func=augment_pipeline,
                                        inp=[image, label],
                                        Tout=[tf.float32, tf.int64]))
augmented_dataset = augmented_dataset.unbatch()
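For anyone who would rather stay in graph mode and skip py_function, something along these lines might also work. It is only a sketch and not what I ended up using: it swaps tf.image.extract_patches and tf.image.central_crop for plain tensor slicing driven by tf.shape, the name augment_pipeline_graph and the toy generator are made up for illustration, and it assumes TF 2.x with channels-last images.

```python
import numpy as np
import tensorflow as tf


def augment_pipeline_graph(image, label):
    """Hypothetical graph-mode variant: 4 corner crops + 1 central crop."""
    shape = tf.shape(image)  # run-time shape, available even when image.shape is None
    height, width = shape[0], shape[1]
    crop_h, crop_w = height // 2, width // 2
    top, left = (height - crop_h) // 2, (width - crop_w) // 2

    crops = tf.stack([
        image[:crop_h, :crop_w, :],                      # top-left
        image[:crop_h, width - crop_w:, :],              # top-right
        image[height - crop_h:, :crop_w, :],             # bottom-left
        image[height - crop_h:, width - crop_w:, :],     # bottom-right
        image[top:top + crop_h, left:left + crop_w, :],  # centre
    ], axis=0)
    labels = tf.tile(tf.reshape(label, [1, 1]), [5, 1])
    return crops, labels


# Toy input standing in for the parsed TFRecords (assumption for this demo).
source = np.arange(4 * 6 * 3, dtype=np.float32).reshape(4, 6, 3)


def gen():
    yield source, 7


dataset = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
        tf.TensorSpec(shape=[], dtype=tf.int64),
    ),
)

augmented = dataset.map(augment_pipeline_graph).unbatch()
results = list(augmented.as_numpy_iterator())
print(len(results), results[0][0].shape)
```

Every crop is cut to height // 2 by width // 2, so the five crops can always be stacked, including for odd image sizes.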
Upvotes: 2