Reputation: 49329
I am loading my dataset as follows,
ds = tf.data.Dataset.list_files("/images/*.png")
train_size = int(0.8 * len(ds))
train_ds = ds.take(train_size)
train_ds = train_ds.map(load_sample)
# Split each image into N smaller tiles (this is the step I want to add here)
train_ds = train_ds.map(preprocessing_train, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.repeat()
train_ds = train_ds.batch(batch_size)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
The images are quite big, and I need to take some measurements after inference, so I would like to avoid resizing them. I have a separate function that takes an image and splits it into tiles, e.g. for a 512x512 image and 256x256 tiles it returns a 2x2x256x256 tensor. I would like to train the network on these 256x256 tiles (during inference I will also run the network on the smaller tiles and then combine them to get the original picture back). Using the Dataset API, how can I split the images into tiles after map(load_sample) and before train_ds.map(preprocessing_train)?
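For context, the kind of tiling function I mean looks roughly like this (a minimal sketch for a single-channel image; the function name and the tile_size default are illustrative, not my exact code):
import tensorflow as tf

def split_into_tiles(image, tile_size=256):
    # (H, W) -> (rows, tile, cols, tile) -> (rows, cols, tile, tile)
    rows = tf.shape(image)[0] // tile_size
    cols = tf.shape(image)[1] // tile_size
    tiles = tf.reshape(image, (rows, tile_size, cols, tile_size))
    return tf.transpose(tiles, (0, 2, 1, 3))

tiles = split_into_tiles(tf.zeros((512, 512)))  # shape (2, 2, 256, 256)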
Upvotes: 1
Views: 522
Reputation: 11631
You can reshape your tiled tensor and call unbatch
to get rid of the extra dimension.
import tensorflow as tf

a = tf.expand_dims(tf.eye(6), 0)
a = tf.concat([a, a, a, a], 0)  # a dataset of 4 (6,6) "images"
ds = tf.data.Dataset.from_tensor_slices(a)  # element shape is (6,6)
# Mimicking your tiling op (just a stand-in with the right output shape)
tiling_op = lambda b: tf.reshape(b, (2, 2, 3, 3))
ds = ds.map(tiling_op)  # element shape is now (2,2,3,3)
# Collapse the 2x2 tile grid into a single leading dimension
reshape_op = lambda b: tf.reshape(b, (-1, 3, 3))
ds = ds.map(reshape_op)  # element shape is now (4,3,3)
# Get rid of the tiled dimension; each element is now a single tile
ds = ds.unbatch()  # element shape is now (3,3)
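Applied to your pipeline it would look something like this (a sketch; split_into_tiles stands in for your own tiling function, assumed to return a (rows, cols, 256, 256) tensor):
train_ds = ds.take(train_size)
train_ds = train_ds.map(load_sample)
# Tile each image and flatten the tile grid so unbatch can split it
train_ds = train_ds.map(lambda img: tf.reshape(split_into_tiles(img), (-1, 256, 256)))
train_ds = train_ds.unbatch()  # each element is now a single (256,256) tile
train_ds = train_ds.map(preprocessing_train, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.repeat()
train_ds = train_ds.batch(batch_size)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
Note that batch_size now counts tiles rather than original images.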
Upvotes: 2