Reputation: 6715
I have an autoencoder that takes an image as an input and produces a new image as an output.
The input image (1x1024x1024x3) is split into patches (1024x32x32x3) before being fed to the network.
Once I have the output, also a batch of patches size 1024x32x32x3, I want to be able to reconstruct a 1024x1024x3 image. I thought I had this sussed by simply reshaping, but here's what happened.
First, the image as read by Tensorflow:
I patched the image with the following code:
patch_size = [1, 32, 32, 1]
patches = tf.extract_image_patches([image],
                                   patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [1024, 32, 32, 3])
Here are a couple of patches from this image:
But it's when I reshape this patch data back into an image that things go pear-shaped.
reconstructed = tf.reshape(patches, [1, 1024, 1024, 3])
converted = tf.image.convert_image_dtype(reconstructed, tf.uint8)
encoded = tf.image.encode_png(converted)
In this example, no processing has been done between patching and reconstructing. I have made a version of the code you can use to test this behaviour. To use it, run the following:
echo "/path/to/test-image.png" > inputs.txt
mkdir images
python3 image_test.py inputs.txt images
For each input image, the code will create one input image, one output image, and one patch image per patch (1024 in all), so comment out the lines that create the input and output images if you're only concerned with saving the patches.
Somebody, please explain what happened :(
Upvotes: 18
Views: 17817
Reputation: 3009
tf.extract_image_patches is quite difficult to use, as it does a lot of stuff in the background. If you only need non-overlapping patches, it's much easier to write the function yourself; you can then reconstruct the full image by inverting each operation in image_to_patches (a sketch of that inverse follows the code sample below).
Code sample (plots original image and patches):
import tensorflow as tf
from skimage import io
import matplotlib.pyplot as plt

def image_to_patches(image, patch_height, patch_width):
    # resize image so that its dimensions are divisible by patch_height and patch_width
    image_height = tf.cast(tf.shape(image)[0], dtype=tf.float32)
    image_width = tf.cast(tf.shape(image)[1], dtype=tf.float32)
    height = tf.cast(tf.ceil(image_height / patch_height) * patch_height, dtype=tf.int32)
    width = tf.cast(tf.ceil(image_width / patch_width) * patch_width, dtype=tf.int32)

    num_rows = height // patch_height
    num_cols = width // patch_width
    # make zero-padding
    image = tf.squeeze(tf.image.resize_image_with_crop_or_pad(image, height, width))

    # get slices along the 0-th axis
    image = tf.reshape(image, [num_rows, patch_height, width, -1])
    # h/patch_h, w, patch_h, c
    image = tf.transpose(image, [0, 2, 1, 3])
    # get slices along the 1-st axis
    # h/patch_h, w/patch_w, patch_w, patch_h, c
    image = tf.reshape(image, [num_rows, num_cols, patch_width, patch_height, -1])
    # num_patches, patch_w, patch_h, c
    image = tf.reshape(image, [num_rows * num_cols, patch_width, patch_height, -1])
    # num_patches, patch_h, patch_w, c
    return tf.transpose(image, [0, 2, 1, 3])
image = io.imread('http://www.petful.com/wp-content/uploads/2011/09/slow-blinking-cat.jpg')
print('Original image shape:', image.shape)
tile_size = 200
image = tf.constant(image)
tiles = image_to_patches(image, tile_size, tile_size)

sess = tf.Session()
I, tiles = sess.run([image, tiles])
print(I.shape)
print(tiles.shape)

plt.figure(figsize=(1 * (4 + 1), 5))
plt.subplot(5, 1, 1)
plt.imshow(I)
plt.title('original')
plt.axis('off')
for i, tile in enumerate(tiles):
    plt.subplot(5, 5, 5 + 1 + i)
    plt.imshow(tile)
    plt.title(str(i))
    plt.axis('off')
plt.show()
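And here is a minimal sketch of the reconstruction, obtained by inverting each operation in image_to_patches step by step (the function name patches_to_image and its arguments are my own; num_rows, num_cols and the patch sizes are the values computed in image_to_patches, and the output has the padded height/width):

def patches_to_image(patches, num_rows, num_cols, patch_height, patch_width):
    # undo the final transpose: num_patches, patch_w, patch_h, c
    image = tf.transpose(patches, [0, 2, 1, 3])
    # split the patch axis back into row and column indices
    image = tf.reshape(image, [num_rows, num_cols, patch_width, patch_height, -1])
    # merge the column slices back into full-width rows: h/patch_h, w, patch_h, c
    image = tf.reshape(image, [num_rows, num_cols * patch_width, patch_height, -1])
    # undo the row transpose: h/patch_h, patch_h, w, c
    image = tf.transpose(image, [0, 2, 1, 3])
    # merge the row slices back into the (padded) image: h, w, c
    return tf.reshape(image, [num_rows * patch_height, num_cols * patch_width, -1])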
Upvotes: 4
Reputation: 1802
Use Update #2 below. First, one small example for your task (TF 1.0):
Consider an image of size (4, 4, 1), converted to patches of size (4, 2, 2, 1) and reconstructed back into the image.
import tensorflow as tf
image = tf.constant([[[1], [2], [3], [4]],
                     [[5], [6], [7], [8]],
                     [[9], [10], [11], [12]],
                     [[13], [14], [15], [16]]])

patch_size = [1, 2, 2, 1]
patches = tf.extract_image_patches([image],
                                   patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [4, 2, 2, 1])
reconstructed = tf.reshape(patches, [1, 4, 4, 1])
rec_new = tf.space_to_depth(reconstructed, 2)
rec_new = tf.reshape(rec_new, [4, 4, 1])

sess = tf.Session()
I, P, R_n = sess.run([image, patches, rec_new])
print(I)
print(I.shape)
print(P.shape)
print(R_n)
print(R_n.shape)
Output:
[[[ 1][ 2][ 3][ 4]]
[[ 5][ 6][ 7][ 8]]
[[ 9][10][11][12]]
[[13][14][15][16]]]
(4, 4, 1)
(4, 2, 2, 1)
[[[ 1][ 2][ 3][ 4]]
[[ 5][ 6][ 7][ 8]]
[[ 9][10][11][12]]
[[13][14][15][16]]]
(4, 4, 1)
Update - for 3 channels (debugging...); this works only when p = sqrt(h):
import tensorflow as tf
import numpy as np
c = 3
h = 1024
p = 32

image = tf.random_normal([h, h, c])
patch_size = [1, p, p, 1]
patches = tf.extract_image_patches([image],
                                   patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [h, p, p, c])
reconstructed = tf.reshape(patches, [1, h, h, c])
rec_new = tf.space_to_depth(reconstructed, p)
rec_new = tf.reshape(rec_new, [h, h, c])

sess = tf.Session()
I, P, R_n = sess.run([image, patches, rec_new])
print(I.shape)
print(P.shape)
print(R_n.shape)
err = np.sum((R_n - I)**2)
print(err)
Output:
(1024, 1024, 3)
(1024, 32, 32, 3)
(1024, 1024, 3)
0.0
Update 2
Reconstructing from the output of extract_image_patches seems difficult. It's easier to use other functions (space_to_batch_nd and batch_to_space_nd) to extract the patches and then reverse the process to reconstruct:
import tensorflow as tf
import numpy as np
c = 3
h = 1024
p = 128

image = tf.random_normal([1, h, h, c])

# Image to Patches Conversion
pad = [[0, 0], [0, 0]]
patches = tf.space_to_batch_nd(image, [p, p], pad)
patches = tf.split(patches, p * p, 0)
patches = tf.stack(patches, 3)
patches = tf.reshape(patches, [(h // p) ** 2, p, p, c])  # integer division, for Python 3

# Do processing on patches
# Using patches here to reconstruct
patches_proc = tf.reshape(patches, [1, h // p, h // p, p * p, c])
patches_proc = tf.split(patches_proc, p * p, 3)
patches_proc = tf.stack(patches_proc, axis=0)
patches_proc = tf.reshape(patches_proc, [p * p, h // p, h // p, c])

reconstructed = tf.batch_to_space_nd(patches_proc, [p, p], pad)

sess = tf.Session()
I, P, R_n = sess.run([image, patches, reconstructed])
print(I.shape)
print(P.shape)
print(R_n.shape)
err = np.sum((R_n - I)**2)
print(err)
Output:
(1, 1024, 1024, 3)
(64, 128, 128, 3)
(1, 1024, 1024, 3)
0.0
You could see other cool tensor transformation functions here: https://www.tensorflow.org/api_guides/python/array_ops
Upvotes: 8
Reputation: 2163
Since I also struggled with this, I am posting a solution that might be useful to others. The trick is to realize that the inverse of tf.extract_image_patches is its gradient, as suggested here. Since the gradient of this op is implemented in TensorFlow, it is easy to build the reconstruction function:
import tensorflow as tf
from keras import backend as K
import numpy as np

def extract_patches(x):
    return tf.extract_image_patches(
        x,
        (1, 3, 3, 1),
        (1, 1, 1, 1),
        (1, 1, 1, 1),
        padding="VALID"
    )

def extract_patches_inverse(x, y):
    _x = tf.zeros_like(x)
    _y = extract_patches(_x)
    grad = tf.gradients(_y, _x)[0]
    # Divide by grad, to "average" together the overlapping patches
    # otherwise they would simply sum up
    return tf.gradients(_y, _x, grad_ys=y)[0] / grad

# Generate 10 fake images; the last dimension can be different than 3
images = np.random.random((10, 28, 28, 3)).astype(np.float32)

# Extract patches
patches = extract_patches(images)

# Reconstruct image
# Notice that original images are only passed to infer the right shape
images_reconstructed = extract_patches_inverse(images, patches)

# Compare with original (evaluating tf.Tensor into a numpy array)
# Here using Keras session
images_r = images_reconstructed.eval(session=K.get_session())

print(np.sum(np.square(images - images_r)))
# 2.3820458e-11
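For what it's worth, here is an untested sketch of the same trick ported to TF 2.x eager mode (my own adaptation, not part of the original answer): tf.gradients is unavailable there, but a persistent tf.GradientTape can compute the same vector-Jacobian products.

import tensorflow as tf

def extract_patches_inverse_v2(images_shape, patches):
    _x = tf.zeros(images_shape)
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(_x)
        _y = tf.image.extract_patches(_x,
                                      sizes=(1, 3, 3, 1),
                                      strides=(1, 1, 1, 1),
                                      rates=(1, 1, 1, 1),
                                      padding="VALID")
    # how many patches each pixel appears in (for averaging the overlaps)
    counts = tape.gradient(_y, _x, output_gradients=tf.ones_like(_y))
    # scatter the patch values back onto the image grid, then average
    summed = tape.gradient(_y, _x, output_gradients=patches)
    return summed / counts

images = tf.random.uniform((10, 28, 28, 3))
patches = tf.image.extract_patches(images, sizes=(1, 3, 3, 1),
                                   strides=(1, 1, 1, 1), rates=(1, 1, 1, 1),
                                   padding="VALID")
images_r = extract_patches_inverse_v2(tf.shape(images), patches)
print(float(tf.reduce_sum(tf.square(images - images_r))))  # ~0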
Upvotes: 12
Reputation: 41
I don't know if the following code is an efficient implementation, but it works!
_, n_row, n_col, n_channel = image.shape
n_patch = n_row * n_col // (patch_size ** 2)  # assume a square patch
patches = tf.image.extract_patches(image,
                                   sizes=[1, patch_size, patch_size, 1],
                                   strides=[1, patch_size, patch_size, 1],
                                   rates=[1, 1, 1, 1],
                                   padding='VALID')
patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])
rows = tf.split(patches, n_col // patch_size, axis=0)  # one group per row of patches (square image)
rows = [tf.concat(tf.unstack(r), axis=1) for r in rows]  # stitch each row horizontally
reconstructed = tf.concat(rows, axis=0)
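A hypothetical round-trip check for this snippet (my own addition, using the 1024x1024x3 / 32x32 case from the question and random data in place of a real image):

import tensorflow as tf

image = tf.random.normal([1, 1024, 1024, 3])
patch_size = 32
_, n_row, n_col, n_channel = image.shape
n_patch = n_row * n_col // (patch_size ** 2)
patches = tf.image.extract_patches(image,
                                   sizes=[1, patch_size, patch_size, 1],
                                   strides=[1, patch_size, patch_size, 1],
                                   rates=[1, 1, 1, 1], padding='VALID')
patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])
rows = tf.split(patches, n_col // patch_size, axis=0)
rows = [tf.concat(tf.unstack(r), axis=1) for r in rows]
reconstructed = tf.concat(rows, axis=0)                       # (1024, 1024, 3)
print(float(tf.reduce_sum((reconstructed - image[0]) ** 2)))  # expect 0.0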
Upvotes: 3
Reputation: 542
This code works for your specific case, as well as for cases where the images are square, the kernel is square, and the image size is divisible by the kernel size.
I did not test it for other cases.
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

size = 1024
k_size = 32
axes_1_2_size = int(np.sqrt((size * size) / (k_size * k_size)))

# Define a placeholder for the image (or load it directly if you prefer)
img = tf.placeholder(tf.int32, shape=(1, size, size, 3))

# Extract patches
patches = tf.image.extract_image_patches(img, ksizes=[1, k_size, k_size, 1],
                                         strides=[1, k_size, k_size, 1],
                                         rates=[1, 1, 1, 1], padding='VALID')

# Reconstruct the image back from the patches
# First separate out the channel dimension
reconstruct = tf.reshape(patches, (1, axes_1_2_size, axes_1_2_size, k_size, k_size, 3))
# Transpose the axes (I got this axes tuple for transpose via experimentation)
reconstruct = tf.transpose(reconstruct, (0, 1, 3, 2, 4, 5))
# Reshape back
reconstruct = tf.reshape(reconstruct, (size, size, 3))

im_arr = # load image with shape (size, size, 3)

# Run the operations
with tf.Session() as sess:
    ps, r = sess.run([patches, reconstruct], feed_dict={img: [im_arr]})

# Plot the reconstructed image to verify
plt.imshow(r)
Upvotes: 2
Reputation: 1
Implemented using im2col and col2im (demonstrated below on single-channel MNIST images):
import numpy as np
import keras
import tensorflow as tf
import matplotlib.pyplot as plt

class ImPatch():
    def __init__(self):
        pass

    def save_image(self, img, N=None):
        plt.imshow(img)
        plt.savefig(str(N))
        plt.clf()

    def get_indices(self, X_shape, HF, WF, stride, pad):
        # get input size
        m, n_C, n_H, n_W = X_shape
        # get output size
        out_h = int((n_H + 2 * pad - HF) / stride) + 1
        out_w = int((n_W + 2 * pad - WF) / stride) + 1
        # ----Compute matrix of index i----
        # Level 1 vector.
        level1 = np.repeat(np.arange(HF), WF)
        # Duplicate for the other channels.
        level1 = np.tile(level1, n_C)
        # Create a vector with an increase by 1 at each level.
        everyLevels = stride * np.repeat(np.arange(out_h), out_w)
        # Create matrix of index i at every level for each channel.
        i = level1.reshape(-1, 1) + everyLevels.reshape(1, -1)
        # ----Compute matrix of index j----
        # Slide 1 vector.
        slide1 = np.tile(np.arange(WF), HF)
        # Duplicate for the other channels.
        slide1 = np.tile(slide1, n_C)
        # Create a vector with an increase by 1 at each slide.
        everySlides = stride * np.tile(np.arange(out_w), out_h)
        # Create matrix of index j at every slide for each channel.
        j = slide1.reshape(-1, 1) + everySlides.reshape(1, -1)
        # ----Compute matrix of index d----
        # This is to mark delimitation for each channel
        # during multi-dimensional arrays indexing.
        d = np.repeat(np.arange(n_C), HF * WF).reshape(-1, 1)
        return i, j, d

    def im2col(self, X, HF, WF, stride, pad):
        # Padding
        X_padded = np.pad(X, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode='constant')
        i, j, d = self.get_indices(X.shape, HF, WF, stride, pad)
        # Multi-dimensional arrays indexing.
        cols = X_padded[:, d, i, j]
        cols = np.concatenate(cols, axis=-1)
        return cols

    def col2im(self, col, X_shape, HF, WF, stride, pad):
        # Get input size
        N, D, H, W = X_shape
        # Add padding if needed.
        H_padded, W_padded = H + 2 * pad, W + 2 * pad
        X_padded = np.zeros((N, D, H_padded, W_padded))
        # Index matrices, necessary to transform our input image into a matrix.
        i, j, d = self.get_indices(X_shape, HF, WF, stride, pad)
        # Retrieve batch dimension by splitting dX_col N times: (X, Y) => (N, X, Y)
        dX_col_reshaped = np.array(np.hsplit(col, N))
        # Reshape our matrix back to image.
        # slice(None) is used to produce the [::] effect, which means "for every element".
        np.add.at(X_padded, (slice(None), d, i, j), dX_col_reshaped)
        # Remove padding from new image if needed.
        if pad == 0:
            return X_padded
        elif type(pad) is int:
            # slice the two spatial axes (the last two), not the first two
            return X_padded[:, :, pad:-pad, pad:-pad]

    def get_patches(self, x, HF, WF, stride, verbose=False):
        x_patches = tf.image.extract_patches(x, sizes=[1, HF, WF, 1], strides=[1, stride, stride, 1], rates=[1, 1, 1, 1], padding='VALID')
        if verbose == True:
            print(x_patches.shape, 'x_patches shape')
        return x_patches

    def get_img(self, x_patches, x_shape, HF, WF, stride, verbose=False):
        x_patches_T = np.transpose(x_patches, (0, 3, 1, 2))
        x_col = self.im2col(X=x_patches_T, HF=1, WF=1, stride=1, pad=0)
        if verbose == True:
            print(x_col.shape, 'x_col shape')
        x_shape = (x_shape[0], x_shape[3], x_shape[1], x_shape[2])
        x_reconstruct = self.col2im(col=x_col, X_shape=x_shape, HF=HF, WF=WF, stride=stride, pad=0)
        x_reconstruct_T = np.transpose(x_reconstruct, (0, 2, 3, 1))
        if verbose == True:
            print(x_reconstruct.shape, 'x_reconstruct shape')
            print(x_reconstruct_T.shape, 'x_reconstruct_T shape')
        return x_reconstruct_T

    def test(self, x, HF, WF, stride, save=True, verbose=True):
        x_patches = self.get_patches(x, HF=HF, WF=WF, stride=stride, verbose=verbose)
        x_reconstruct = self.get_img(x_patches, x_shape=x.shape, HF=HF, WF=WF, stride=stride, verbose=verbose)
        if save == True:
            idx = np.random.randint(0, x.shape[0])
            self.save_image(img=x[idx].reshape(28, 28), N=0)
            self.save_image(img=x_reconstruct[idx].reshape(28, 28), N=1)
        return x_reconstruct

impatch = ImPatch()
(x, _), (_, _) = keras.datasets.mnist.load_data()
x = np.expand_dims(x[0:10], axis=-1)
HF, WF, stride = 4, 4, 4
impatch.test(x, HF=HF, WF=WF, stride=stride)
Upvotes: 0
Reputation: 11
I may be a bit late, but since I got this working with TF 2.3, it might prove useful for others. The following code works for non-overlapping patches, single- or multi-channel:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers

class PatchesToImage(layers.Layer):
    def __init__(self, imgh, imgw, imgc, patsz, is_squeeze=True, **kwargs):
        super(PatchesToImage, self).__init__(**kwargs)
        self.H = (imgh // patsz) * patsz
        self.W = (imgw // patsz) * patsz
        self.C = imgc
        self.P = patsz
        self.is_squeeze = is_squeeze

    def call(self, inputs):
        bs = tf.shape(inputs)[0]
        rows, cols = self.H // self.P, self.W // self.P
        patches = tf.reshape(inputs, [bs, rows, cols, -1, self.C])
        pats_by_clist = tf.unstack(patches, axis=-1)

        def tile_patches(ii):
            pats = pats_by_clist[ii]
            img = tf.nn.depth_to_space(pats, self.P)
            return img

        img = tf.map_fn(fn=tile_patches, elems=tf.range(self.C), fn_output_signature=inputs.dtype)
        img = tf.squeeze(img, axis=-1)
        img = tf.transpose(img, perm=[1, 2, 3, 0])
        C = tf.shape(img)[-1]
        img = tf.cond(tf.logical_and(tf.constant(self.is_squeeze), C == 1),
                      lambda: tf.squeeze(img, axis=-1), lambda: img)
        return img
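A hypothetical round-trip check for this layer (my own addition; it assumes the patches come from tf.image.extract_patches with non-overlapping 32x32 windows, so each patch is flattened in (row, column, channel) order):

image = tf.random.normal([2, 1024, 1024, 3])
patches = tf.image.extract_patches(image,
                                   sizes=[1, 32, 32, 1],
                                   strides=[1, 32, 32, 1],
                                   rates=[1, 1, 1, 1],
                                   padding='VALID')        # (2, 32, 32, 32*32*3)
rebuilt = PatchesToImage(1024, 1024, 3, 32)(patches)       # (2, 1024, 1024, 3)
print(float(tf.reduce_sum((rebuilt - image) ** 2)))        # expect 0.0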
Upvotes: 1
Reputation: 447
TF 2.0 users can use tf.nn.space_to_depth and tf.nn.depth_to_space if you aren't doing overlapping blocks.
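A minimal sketch of that suggestion (my own illustration, using the 1024x1024x3 image and 32x32 patches from the question):

import tensorflow as tf

p = 32
image = tf.random.normal([1, 1024, 1024, 3])

# image -> patches: each output location of space_to_depth holds one
# flattened p x p patch, ordered (row-in-patch, col-in-patch, channel)
blocks = tf.nn.space_to_depth(image, p)        # (1, 32, 32, p*p*3)
patches = tf.reshape(blocks, [-1, p, p, 3])    # (1024, 32, 32, 3)

# patches -> image: invert the two steps above
blocks = tf.reshape(patches, [1, 1024 // p, 1024 // p, p * p * 3])
reconstructed = tf.nn.depth_to_space(blocks, p)  # (1, 1024, 1024, 3)

print(float(tf.reduce_sum((reconstructed - image) ** 2)))  # expect 0.0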
Upvotes: 1
Reputation: 151
To specifically address the initial question, which is 'Reconstructing an image after using extract_image_patches', I propose using tf.scatter_nd() and building a stratified image. This will work even in situations where the extracted patches overlap or the image is under-sampled. Here is my proposed solution.
import cv2
import numpy as np
import tensorflow as tf

# Function to extract patches using 'extract_image_patches'
def img_to_patches(raw_input, _patch_size=(128, 128), _stride=100):
    with tf.variable_scope('im2_patches'):
        patches = tf.image.extract_image_patches(
            images=raw_input,
            ksizes=[1, _patch_size[0], _patch_size[1], 1],
            strides=[1, _stride, _stride, 1],
            rates=[1, 1, 1, 1],
            padding='SAME'
        )
        h = tf.shape(patches)[1]
        w = tf.shape(patches)[2]
        patches = tf.reshape(patches, (patches.shape[0], -1, _patch_size[0], _patch_size[1], 3))
    return patches, (h, w)

# Function to reconstruct image
def patches_to_img(update, _block_shape, _stride=100):
    with tf.variable_scope('patches2im'):
        _h = _block_shape[0]
        _w = _block_shape[1]
        bs = tf.shape(update)[0]         # batch size
        n_patches = tf.shape(update)[1]  # number of patches (renamed from 'np' to avoid shadowing numpy)
        ps_h = tf.shape(update)[2]       # patch height
        ps_w = tf.shape(update)[3]       # patch width
        col_ch = tf.shape(update)[4]     # colour channel count

        wout = (_w - 1) * _stride + ps_w  # recalculate output shape of "extract_image_patches" including padded pixels
        hout = (_h - 1) * _stride + ps_h  # recalculate output shape of "extract_image_patches" including padded pixels

        x, y = tf.meshgrid(tf.range(ps_w), tf.range(ps_h))
        x = tf.reshape(x, (1, 1, ps_h, ps_w, 1, 1))
        y = tf.reshape(y, (1, 1, ps_h, ps_w, 1, 1))
        xstart, ystart = tf.meshgrid(tf.range(0, (wout - ps_w) + 1, _stride),
                                     tf.range(0, (hout - ps_h) + 1, _stride))

        bb = tf.zeros((1, n_patches, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(bs), (-1, 1, 1, 1, 1, 1))  # batch indices
        yy = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + y + tf.reshape(ystart, (1, -1, 1, 1, 1, 1))  # y indices
        xx = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + x + tf.reshape(xstart, (1, -1, 1, 1, 1, 1))  # x indices
        cc = tf.zeros((bs, n_patches, ps_h, ps_w, 1, 1), dtype=tf.int32) + tf.reshape(tf.range(col_ch), (1, 1, 1, 1, -1, 1))  # colour indices
        dd = tf.zeros((bs, 1, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(n_patches), (1, -1, 1, 1, 1, 1))  # shift indices

        idx = tf.concat([bb, yy, xx, cc, dd], -1)

        stratified_img = tf.scatter_nd(idx, update, (bs, hout, wout, col_ch, n_patches))
        stratified_img = tf.transpose(stratified_img, (0, 4, 1, 2, 3))

        stratified_img_count = tf.scatter_nd(idx, tf.ones_like(update), (bs, hout, wout, col_ch, n_patches))
        stratified_img_count = tf.transpose(stratified_img_count, (0, 4, 1, 2, 3))

    with tf.variable_scope("consolidate"):
        sum_stratified_img = tf.reduce_sum(stratified_img, axis=1)
        stratified_img_count = tf.reduce_sum(stratified_img_count, axis=1)
        reconstructed_img = tf.divide(sum_stratified_img, stratified_img_count)
    return reconstructed_img, stratified_img

if __name__ == "__main__":
    # load initial image
    image_org = cv2.imread('orig_img.jpg')
    # add batch dimension
    image = np.expand_dims(image_org, axis=0)

    # set parameters
    patch_size = (228, 228)
    stride = 200

    input_img = tf.placeholder(dtype=tf.float32, shape=image.shape, name="input_img")

    # extract patches using "extract_image_patches()"
    extracted_patches, block_shape = img_to_patches(input_img, _patch_size=patch_size, _stride=stride)
    # block_shape is the number of patches extracted in the x and in the y dimension
    # extracted_patches.shape = (1, block_shape[0] * block_shape[1], patch_size[0], patch_size[1], 3)
    reconstructed_img, stratified_img = patches_to_img(extracted_patches, block_shape, stride)  # reconstruct image

    with tf.Session() as sess:
        ep, bs, ri, si = sess.run([extracted_patches, block_shape, reconstructed_img, stratified_img],
                                  feed_dict={input_img: image})
    # print(bs)
    si = si.astype(np.int32)

    # show reconstructed image
    cv2.imshow('sd', ri[0, :, :, :].astype(np.float32) / 255)
    cv2.waitKey(0)

    # show stratified images
    for i in range(si.shape[1]):
        im_1 = si[0, i, :, :, :]
        cv2.imshow('sd', im_1.astype(np.float32) / 255)
        cv2.waitKey(0)
The above solution should work for batched images of arbitrary colour channel dimensions.
Upvotes: 0