bodokaiser

Reputation: 15752

Convert tf.extract_image_patches to batch shape

I batch my data together:

batch_size = 50
min_after_dequeue = 100
capacity = min_after_dequeue + 3 * batch_size

mr_batch, us_batch = tf.train.shuffle_batch(
      [mr, us], batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)
mr_batch, us_batch

This gives me tensor shapes:

(<tf.Tensor 'shuffle_batch_2:0' shape=(50, 466, 394, 1) dtype=int16>,
 <tf.Tensor 'shuffle_batch_2:1' shape=(50, 366, 323, 1) dtype=uint8>)

Then I resize the MR images so that both batches have the same resolution:

mr_batch = tf.image.resize_bilinear(mr_batch, [366, 323])
mr_batch, us_batch

Which gives me shapes:

(<tf.Tensor 'ResizeBilinear_13:0' shape=(50, 366, 323, 1) dtype=float32>,
 <tf.Tensor 'shuffle_batch_2:1' shape=(50, 366, 323, 1) dtype=uint8>)

Finally I extract image patches:

us_patch = tf.extract_image_patches(us_batch, [1, 7, 7, 1], [1, 2, 2, 1], [1, 1, 1, 1], 'SAME')
mr_patch = tf.extract_image_patches(mr_batch, [1, 7, 7, 1], [1, 2, 2, 1], [1, 1, 1, 1], 'SAME')

us_patch, mr_patch

The patches have shape:

(<tf.Tensor 'ExtractImagePatches_8:0' shape=(50, 92, 81, 1225) dtype=uint8>,
 <tf.Tensor 'ExtractImagePatches_9:0' shape=(50, 92, 81, 1225) dtype=float32>)
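
As a sanity check on these shapes, the output shape of tf.extract_image_patches with 'SAME' padding can be worked out by hand: each spatial dimension becomes ceil(input / stride), and the depth becomes ksize_rows * ksize_cols * channels. The helper below is a hypothetical pure-Python sketch (patches_shape_same is not a TensorFlow function) and assumes rates of 1 and Python 3 division.

```python
import math

def patches_shape_same(input_shape, ksize, stride):
    """Hypothetical helper: expected output shape for 'SAME' padding, rates of 1."""
    batch, rows, cols, channels = input_shape
    out_rows = math.ceil(rows / stride[0])       # spatial size depends only on the stride
    out_cols = math.ceil(cols / stride[1])
    depth = ksize[0] * ksize[1] * channels       # one flattened patch per output location
    return (batch, out_rows, out_cols, depth)

# Parameters as written in the call above.
print(patches_shape_same((50, 366, 323, 1), ksize=(7, 7), stride=(2, 2)))

# Parameters that would reproduce the reported shape.
print(patches_shape_same((50, 366, 323, 1), ksize=(35, 35), stride=(4, 4)))
```

Under these assumptions, 7x7 patches with stride 2 would give (50, 183, 162, 49), whereas the reported (50, 92, 81, 1225) is what, for example, 35x35 patches with stride 4 would produce. The printed shapes may therefore come from a different run than the posted call.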

I would now like to convert this shape to (50*1225, 92, 81) so I can feed it to my train step.

What is this tensor operation called?

Upvotes: 1

Views: 607

Answers (1)

Olivier Moindrot

Reputation: 28198

You can use tf.reshape with the special value -1 for one dimension; TensorFlow infers that dimension from the total number of elements:

tf.reshape(us_patch, [-1, 92, 81])
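
One caveat worth checking: a reshape keeps the elements in row-major order. Reshaping [50, 92, 81, 1225] directly to [-1, 92, 81] therefore yields the right shape, but an output row is not the slice us_patch[b, :, :, p]. If that slice is what the train step expects, the last axis probably has to be moved next to the batch axis first, for example with tf.transpose(us_patch, [0, 3, 1, 2]), before reshaping. The sketch below illustrates the difference with NumPy on a small stand-in array, on the assumption that np.transpose and np.reshape order elements the same way as their TensorFlow counterparts. The array sizes are made up for the example.

```python
import numpy as np

# Small stand-in for a [batch, rows, cols, depth] patch tensor.
batch, rows, cols, depth = 2, 3, 4, 5
x = np.arange(batch * rows * cols * depth).reshape(batch, rows, cols, depth)

# Move depth next to batch, then merge the two leading axes.
moved = np.transpose(x, (0, 3, 1, 2))      # [batch, depth, rows, cols]
stacked = moved.reshape(-1, rows, cols)    # [batch * depth, rows, cols]

# A plain reshape gives the same shape but mixes the axes.
naive = x.reshape(-1, rows, cols)

b, p = 1, 2
print(stacked.shape)
print(np.array_equal(stacked[b * depth + p], x[b, :, :, p]))  # transposed version
print(np.array_equal(naive[b * depth + p], x[b, :, :, p]))    # plain reshape
```

Only the transposed version maps output row b * depth + p back to the slice x[b, :, :, p]. Whether that layout is the one needed depends on what the train step expects.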

However, this can be dangerous: if the input shape is not what you expect (for instance, if us_patch has shape [50, 92, 81, 1000]), TensorFlow will not raise an error and will silently reshape the tensor to [50*1000, 92, 81].
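
One way to guard against this is to check the input shape before reshaping and to spell out the full target shape instead of using -1. The sketch below shows the idea with NumPy on small made-up shapes; checked_reshape is a hypothetical helper, not a TensorFlow or NumPy function. In TensorFlow, a similar check could probably be written against us_patch.get_shape().as_list(), assuming the static shape is fully known.

```python
import numpy as np

def checked_reshape(x, expected_shape, new_shape):
    """Hypothetical helper: reshape only if x has exactly the expected shape."""
    if tuple(x.shape) != tuple(expected_shape):
        raise ValueError("expected shape %s, got %s"
                         % (tuple(expected_shape), tuple(x.shape)))
    return x.reshape(new_shape)

# Matching input: the reshape goes through.
good = np.zeros((2, 3, 4, 5))
print(checked_reshape(good, (2, 3, 4, 5), (2 * 5, 3, 4)).shape)

# Unexpected last dimension: rejected instead of silently reshaped.
bad = np.zeros((2, 3, 4, 7))
try:
    checked_reshape(bad, (2, 3, 4, 5), (2 * 5, 3, 4))
except ValueError as err:
    print("rejected:", err)
```

Spelling out the target shape helps on its own as well: with an explicit first dimension, the reshape fails whenever the element counts do not match, whereas -1 absorbs the difference.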

Upvotes: 1
