I want to extract patches from my original images to use them as input for a CNN.
After a little research I found a way to extract patches with tensorflow.compat.v1.extract_image_patches.
Since these need to be reshaped to "image format" I implemented a method reshape_image_patches to reshape them and store the reshaped patches in an array.
image_patches2 = []
def reshape_image_patches(image_patches, sess, ksize_rows, ksize_cols):
    a = sess.run(tf.shape(image_patches))
    nr, nc = a[1], a[2]
    for i in range(nr):
        for j in range(nc):
            patch = tf.reshape(image_patches[0,i,j,], [ksize_rows, ksize_cols, 3])
            image_patches2.append(patch)
    return image_patches2
How can I use this in combination with Keras generators to make these patches the input of my CNN?
Edit 1:
I have tried the approach from "Load tensorflow images and create patches":
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    <directory>,
    label_mode=None,
    seed=1,
    subset='training',
    validation_split=0.1,
    image_size=(900, 900))
get_patches = lambda x: (tf.reshape(
    tf.image.extract_patches(
        x,
        sizes=[1, 16, 16, 1],
        strides=[1, 8, 8, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'), (111*111, 16, 16, 3)))
dataset = dataset.map(get_patches)
fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images = next(iter(dataset))
for index, image in enumerate(images):
    ax = plt.subplot(2, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)
plt.show()
The line images = next(iter(dataset)) raises this error: InvalidArgumentError: Input to reshape is a tensor with 302800896 values, but the requested shape has 9462528 [[{{node Reshape}}]]
Does somebody know how to fix this?
Upvotes: 2
Views: 1049
tf.reshape does not change the order or the total number of elements in a tensor. As the error states, the tf.reshape call in your lambda function is trying to reduce the total number of elements from 302800896 to 9462528.
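The two numbers in the error message can be checked by hand. This is a minimal sketch, assuming the 900x900 RGB images, 16x16 patches and stride of 8 from the question; the helper name patches_per_axis is hypothetical and only used for illustration.

```python
def patches_per_axis(image_size: int, patch_size: int, stride: int) -> int:
    """Number of patch positions along one axis with padding='VALID'."""
    return (image_size - patch_size) // stride + 1


per_axis = patches_per_axis(900, 16, 8)
# Elements in the requested shape (111*111, 16, 16, 3)
requested = per_axis * per_axis * 16 * 16 * 3
# Element count reported by the error message
actual = 302800896

print(per_axis)             # 111
print(requested)            # 9462528
print(actual // requested)  # 32
print(actual % requested)   # 0
```

The tensor holds exactly 32 times the requested number of elements. That factor matches the default batch_size=32 of image_dataset_from_directory, so each dataset element is most likely a batch of 32 images rather than a single image.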
In the example below, I have recreated your scenario by passing 2 as the shape argument to tf.reshape. That shape cannot accommodate all the elements of the original tensor, so it throws the error:
Code -
%tensorflow_version 2.x
import tensorflow as tf
t1 = tf.Variable([1,2,2,4,5,6])
t2 = tf.reshape(t1, 2)
Output -
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-3-0ff1d701ff22> in <module>()
3 t1 = tf.Variable([1,2,2,4,5,6])
4
----> 5 t2 = tf.reshape(t1, 2)
3 frames
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: Input to reshape is a tensor with 6 values, but the requested shape has 2 [Op:Reshape]
tf.reshape can change the shape in which the elements are arranged, but the total number of elements must remain the same. So the fix is to change the shape to [2,3]:
Code -
%tensorflow_version 2.x
import tensorflow as tf
t1 = tf.Variable([1,2,2,4,5,6])
t2 = tf.reshape(t1, [2,3])
print(t2)
Output -
tf.Tensor(
[[1 2 2]
[4 5 6]], shape=(2, 3), dtype=int32)
To solve your problem, either extract patches (tf.image.extract_patches) of the size that you are reshaping to, or change the shape passed to tf.reshape so that it matches the size of the extracted patches.
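One way to apply the second option is to let tf.reshape infer the number of patches by passing -1 as the leading dimension, so the shape stays valid for any batch size. The sketch below is not the asker's exact pipeline: it assumes a small random tensor as a stand-in for the image directory, and the names PATCH, STRIDE and get_patches are chosen for illustration.

```python
import tensorflow as tf

PATCH = 16   # patch height and width (assumed, as in the question)
STRIDE = 8   # step between neighbouring patches (assumed, as in the question)


def get_patches(images):
    """Split a batch of images (B, H, W, C) into patches (N, PATCH, PATCH, C)."""
    channels = images.shape[-1]
    patches = tf.image.extract_patches(
        images,
        sizes=[1, PATCH, PATCH, 1],
        strides=[1, STRIDE, STRIDE, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')
    # -1 lets TensorFlow infer the patch count, whatever the batch size is
    return tf.reshape(patches, (-1, PATCH, PATCH, channels))


# Stand-in for the image dataset: 4 random 64x64 RGB images in batches of 2
images = tf.random.uniform((4, 64, 64, 3))
dataset = tf.data.Dataset.from_tensor_slices(images).batch(2)
patch_dataset = dataset.map(get_patches)

first = next(iter(patch_dataset))
print(first.shape)  # (98, 16, 16, 3): 2 images x 7 x 7 patches each
```

Each element of patch_dataset then holds all patches of one image batch. Calling patch_dataset.unbatch().batch(n) afterwards would regroup them into batches of n patches, which is one possible way to feed them to a Keras model.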
I would also suggest looking into other tf.image functionality, such as tf.image.central_crop and tf.image.crop_and_resize.
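As a rough illustration of those two functions, the sketch below crops a random stand-in batch; the image size, box coordinates and crop size are arbitrary assumptions, not values taken from the question.

```python
import tensorflow as tf

# Stand-in batch: 2 random 64x64 RGB images
images = tf.random.uniform((2, 64, 64, 3))

# Keep the central 50% of each image along both spatial axes
center = tf.image.central_crop(images, central_fraction=0.5)

# Boxes are [y1, x1, y2, x2] in normalized coordinates:
# top-left quadrant of image 0, bottom-right quadrant of image 1
boxes = tf.constant([[0.0, 0.0, 0.5, 0.5],
                     [0.5, 0.5, 1.0, 1.0]])
crops = tf.image.crop_and_resize(
    images, boxes, box_indices=[0, 1], crop_size=[16, 16])

print(center.shape)  # (2, 32, 32, 3)
print(crops.shape)   # (2, 16, 16, 3)
```

Unlike tf.image.extract_patches, crop_and_resize resamples each box to crop_size, so the crops are interpolated rather than exact copies of the source pixels.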
Upvotes: 2