mar_ey

Reputation: 77

Using tf extract_image_patches for input to a CNN?

I want to extract patches from my original images to use them as input for a CNN. After a little research, I found that patches can be extracted with tensorflow.compat.v1.extract_image_patches.

Since these patches need to be reshaped back to "image format", I implemented a function reshape_image_patches that reshapes them and stores the reshaped patches in a list.

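import tensorflow as tf

def reshape_image_patches(image_patches, sess, ksize_rows, ksize_cols):
    # Evaluate the dynamic shape: [batch, grid_rows, grid_cols, ksize_rows * ksize_cols * 3]
    a = sess.run(tf.shape(image_patches))
    nr, nc = a[1], a[2]
    image_patches2 = []  # local list, so repeated calls don't accumulate old patches
    for i in range(nr):
        for j in range(nc):
            # Turn each flattened patch of the first image back into (rows, cols, channels)
            patch = tf.reshape(image_patches[0, i, j], [ksize_rows, ksize_cols, 3])
            image_patches2.append(patch)
    return image_patches2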
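For reference, the per-patch loop above can be replaced by one vectorized tf.reshape, because tf.image.extract_patches (the TF2 counterpart of extract_image_patches) stores each patch flattened in row, column, channel order in its last dimension. The sketch below assumes TensorFlow 2 in eager mode (no session), RGB input, and illustrative 4x4 patches cut from a tiny 8x8 test image; it checks the result against the loop.

```python
import numpy as np
import tensorflow as tf

ksize_rows, ksize_cols = 4, 4  # illustrative patch size

# Tiny test batch: one 8x8 RGB image with distinct pixel values
image = tf.reshape(tf.range(8 * 8 * 3, dtype=tf.float32), [1, 8, 8, 3])

patches = tf.image.extract_patches(
    image,
    sizes=[1, ksize_rows, ksize_cols, 1],
    strides=[1, ksize_rows, ksize_cols, 1],  # non-overlapping patches
    rates=[1, 1, 1, 1],
    padding='VALID')
# patches.shape == (1, 2, 2, 48): a 2x2 grid, each patch flattened to 4*4*3 values

# One reshape turns every flattened patch back into "image format";
# -1 lets TensorFlow infer the patch count (batch * grid rows * grid cols)
image_patches = tf.reshape(patches, [-1, ksize_rows, ksize_cols, 3])
print(image_patches.shape)  # (4, 4, 4, 3)

# Same result as the per-patch loop (first image only)
nr, nc = patches.shape[1], patches.shape[2]
for i in range(nr):
    for j in range(nc):
        loop_patch = tf.reshape(patches[0, i, j], [ksize_rows, ksize_cols, 3])
        np.testing.assert_array_equal(loop_patch, image_patches[i * nc + j])

# The first patch is simply the top-left 4x4 corner of the image
np.testing.assert_array_equal(image_patches[0], image[0, :4, :4, :])
```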

How can I use this in combination with Keras generators to make these patches the input of my CNN?

Edit 1:

I have tried the approach from "Load tensorflow images and create patches":

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    <directory>,
    label_mode=None,
    seed=1,
    subset='training',
    validation_split=0.1,
    image_size=(900, 900))

get_patches = lambda x: (tf.reshape(
    tf.image.extract_patches(
        x,
        sizes=[1, 16, 16, 1],
        strides=[1, 8, 8, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'),  (111*111, 16, 16, 3)))

dataset = dataset.map(get_patches)

fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images = next(iter(dataset))
for index, image in enumerate(images):
    ax = plt.subplot(2, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)
plt.show()

On the line images = next(iter(dataset)) I get the error:

InvalidArgumentError: Input to reshape is a tensor with 302800896 values, but the requested shape has 9462528 [[{{node Reshape}}]]

Does somebody know how to fix this?

Upvotes: 2

Views: 1049

Answers (1)

user11530462

tf.reshape does not change the order or the total number of elements in a tensor. As the error states, the tf.reshape in your lambda function is trying to reduce the total number of elements from 302800896 to 9462528.

In the example below, I recreate your scenario by passing 2 as the shape argument to tf.reshape. That shape cannot accommodate all the elements of the original tensor, so it throws the error:

Code -

%tensorflow_version 2.x
import tensorflow as tf
t1 = tf.Variable([1,2,2,4,5,6])

t2 = tf.reshape(t1, 2)

Output -

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-3-0ff1d701ff22> in <module>()
      3 t1 = tf.Variable([1,2,2,4,5,6])
      4 
----> 5 t2 = tf.reshape(t1, 2)

3 frames
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: Input to reshape is a tensor with 6 values, but the requested shape has 2 [Op:Reshape]

With tf.reshape, the arrangement of the elements can change, but the total number of elements must remain the same. So the fix is to change the shape to [2,3], which holds exactly 6 elements:

Code -

%tensorflow_version 2.x
import tensorflow as tf
t1 = tf.Variable([1,2,2,4,5,6])

t2 = tf.reshape(t1, [2,3])
print(t2)

Output -

tf.Tensor(
[[1 2 2]
 [4 5 6]], shape=(2, 3), dtype=int32)
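A related detail: one dimension of the new shape may be given as -1, and tf.reshape then infers it from the total element count. This only works when the other dimensions divide the element count evenly. A small sketch with the same 6 values (a plain tf.constant here instead of tf.Variable):

```python
import tensorflow as tf

t1 = tf.constant([1, 2, 2, 4, 5, 6])

# -1 is inferred as 6 // 2 = 3, giving the same (2, 3) result as above
t2 = tf.reshape(t1, [2, -1])
print(t2.shape)  # (2, 3)

# 6 elements cannot be split into rows of 4, so this still fails
try:
    tf.reshape(t1, [-1, 4])
except tf.errors.InvalidArgumentError as e:
    print("InvalidArgumentError:", e.message)
```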

To solve your problem, either extract patches (tf.image.extract_patches) whose total size matches the shape you pass to tf.reshape, OR change the shape in tf.reshape to match the size of the extracted patches.
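The mismatch in the question comes from the batch dimension. image_dataset_from_directory batches images by default (batch_size=32), so each element passed to the mapped lambda has shape (32, 900, 900, 3), not (900, 900, 3). With 16x16 patches and stride 8, extract_patches yields a 111x111 grid per image ((900 - 16) // 8 + 1 = 111), each patch holding 16*16*3 = 768 values, and the two numbers in the error differ by exactly that factor of 32. The sketch below checks the arithmetic and applies the second option (changing the reshape), using -1 for the patch count. It assumes TensorFlow 2.x and replaces the image directory with a small in-memory dataset of random 900x900 images so it runs on its own.

```python
import tensorflow as tf

# The element counts from the error message differ by exactly the batch size
per_image = 111 * 111 * 16 * 16 * 3
print(per_image)       # 9462528   (the requested shape)
print(32 * per_image)  # 302800896 (the actual input: a batch of 32 images)

# Stand-in for image_dataset_from_directory(..., label_mode=None):
# 4 random 900x900 RGB images, batched in pairs (assumed sizes, for illustration)
images = tf.random.uniform([4, 900, 900, 3], maxval=255.0)
dataset = tf.data.Dataset.from_tensor_slices(images).batch(2)

def get_patches(x):
    # x has shape (batch, 900, 900, 3)
    patches = tf.image.extract_patches(
        x,
        sizes=[1, 16, 16, 1],
        strides=[1, 8, 8, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')  # -> (batch, 111, 111, 768)
    # -1 absorbs the batch dimension as well as the 111*111 patch grid
    return tf.reshape(patches, [-1, 16, 16, 3])

# Split the batches of patches into single patches, then rebatch them
# into mini-batches for the CNN (64 is an arbitrary choice)
patch_dataset = dataset.map(get_patches).unbatch().batch(64)

first = next(iter(patch_dataset))
print(first.shape)  # (64, 16, 16, 3)
```

A tf.data.Dataset such as patch_dataset can be passed directly to a Keras model's fit method instead of a Keras generator. Note that with label_mode=None the dataset contains no labels, so labels would still have to be paired with the patches for supervised training.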

I would also suggest looking into other tf.image functions such as tf.image.central_crop and tf.image.crop_and_resize.
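As a rough illustration of those two functions (the image size, box coordinates and crop size below are arbitrary assumptions): tf.image.central_crop keeps a central fraction of the image, while tf.image.crop_and_resize cuts boxes given in normalized [y1, x1, y2, x2] coordinates and resizes each crop to a fixed size, which also yields a stack of patches in image format.

```python
import tensorflow as tf

image = tf.random.uniform([900, 900, 3])  # one RGB image, values in [0, 1)

# Keep the central 50% of the height and width: 900 -> 450
center = tf.image.central_crop(image, central_fraction=0.5)
print(center.shape)  # (450, 450, 3)

# Two boxes in normalized [y1, x1, y2, x2] coordinates, both taken from
# image 0 of the batch and resized to 16x16
boxes = tf.constant([[0.0, 0.0, 0.5, 0.5],       # top-left quarter
                     [0.25, 0.25, 0.75, 0.75]])  # central region
crops = tf.image.crop_and_resize(
    image[tf.newaxis],  # crop_and_resize expects a 4-D batch
    boxes,
    box_indices=tf.zeros([2], dtype=tf.int32),
    crop_size=[16, 16])
print(crops.shape)  # (2, 16, 16, 3)
```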

Upvotes: 2
