Reputation: 1485
I have a 3D image with shape $D \times H \times W$. I successfully extracted overlapping patches of shape $p_d \times p_h \times p_w$ from it. For each patch, I do some processing. Now I would like to rebuild the image from the processed patches so that the new image has the same shape as the original image. Could you help me do it?
This is my code to extract the patches:
def patch_extract_3D(input, patch_shape, xstep=1, ystep=1, zstep=1):
    """Extract (possibly overlapping) 3-D patches from a 3-D array.

    Patch origins lie on a grid with spacing ``(xstep, ystep, zstep)`` along
    axes 0, 1 and 2.  Along axis ``k`` the number of origins is
    ``(input.shape[k] - patch_shape[k] + 1) // step_k``.

    NOTE(review): for steps > 1 this count can omit the last valid window
    position (the exact count would be ``(shape - patch) // step + 1``).  It is
    kept as-is so the patch count matches what ``stuff_patches_3D`` expects.

    Args:
        input: 3-D array-like to cut into patches.
        patch_shape: sequence of three ints, the size of each patch.
        xstep, ystep, zstep: spacing between patch origins along axes 0, 1, 2.

    Returns:
        A new array of shape ``(n_patches, *patch_shape)``.  Patches are ordered
        row-major over the origin grid (axis 0 slowest).  The result never shares
        memory with ``input``, so patches can be modified in place safely.
    """
    image = np.asarray(input)
    patch_shape = tuple(int(p) for p in patch_shape)
    steps = (xstep, ystep, zstep)
    # Floor division: under Python 3 '/' would yield floats, which as_strided rejects.
    grid = tuple((image.shape[k] - patch_shape[k] + 1) // steps[k] for k in range(3))
    windows = np.lib.stride_tricks.as_strided(
        image,
        shape=grid + patch_shape,
        strides=tuple(image.strides[k] * steps[k] for k in range(3)) + image.strides,
        writeable=False,  # the windows overlap in memory; never write through this view
    )
    n_patches = grid[0] * grid[1] * grid[2]
    # Copy first: reshaping the strided view can otherwise return a view that
    # aliases ``input`` (e.g. when a patch spans whole trailing axes).
    return windows.copy().reshape((n_patches,) + patch_shape)
This is how I process the patches (simply multiplying each one by 2):
# Example per-patch processing: double every patch in place.
# Each `patch` is a view of one row of patches_3D, so `*=` updates patches_3D itself.
for patch in patches_3D:
    patch *= 2
Now I need to reassemble `patches_3D` back into an image with the original shape. Thanks!
Here is example code:
import numpy as np

# Build a small 4x4x6 test volume and cut it into overlapping 2x2x2 patches.
patch_shape = [2, 2, 2]
image = np.arange(4 * 4 * 6).reshape(4, 4, 6)
patches_3D = patch_extract_3D(image, patch_shape)
print(patches_3D.shape)

# Placeholder processing: double each patch in place (each `patch` is a view of a row).
for patch in patches_3D:
    patch *= 2
print(patches_3D.shape)
Upvotes: 1
Views: 1551
Reputation: 53079
This will do the reverse. However, since your patches overlap, the result is only well-defined if their values agree where they overlap. Make sure to pass the same `xstep`, `ystep` and `zstep` that you used for extraction:
def stuff_patches_3D(out_shape, patches, xstep=12, ystep=12, zstep=12):
    """Reassemble 3-D patches into an array of shape ``out_shape``.

    This is the inverse of ``patch_extract_3D``.  ``patches`` must hold the
    patches in row-major origin order, either as ``(nx * ny * nz, pd, ph, pw)``
    or as ``(nx, ny, nz, pd, ph, pw)``.  Here
    ``n_k = (out_shape[k] - p_k + 1) // step_k``.

    NOTE: the step defaults here are 12, whereas ``patch_extract_3D`` defaults
    to 1.  Pass the steps that were used for extraction.

    Where patches overlap, the result is only meaningful if their values agree.
    If they disagree, one of the overlapping values is kept, deterministically.
    Voxels not covered by any patch are 0.

    Returns:
        Array of shape ``out_shape`` with dtype ``patches.dtype``.

    Raises:
        ValueError: if the patch is larger than ``out_shape`` or the number of
            patches does not match the origin grid.
    """
    patches = np.asarray(patches)
    out = np.zeros(out_shape, patches.dtype)
    patch_shape = patches.shape[-3:]
    steps = (xstep, ystep, zstep)
    grid = tuple((out.shape[k] - patch_shape[k] + 1) // steps[k] for k in range(3))
    if min(grid) < 0:
        raise ValueError("patch shape %r exceeds output shape %r" % (patch_shape, out.shape))
    patches_6D = patches.reshape(grid + patch_shape)
    # Assign one in-patch offset at a time.  Each target is a strided slice of
    # distinct voxels, so the write is well-defined.  A vectorised write through
    # a self-overlapping as_strided view is unpredictable.
    for offset in np.ndindex(*patch_shape):
        target = tuple(slice(o, o + n * s, s) for o, n, s in zip(offset, grid, steps))
        out[target] = patches_6D[(Ellipsis,) + offset]
    return out
Update: here is a safer version that averages overlapping pixels:
def stuff_patches_3D(out_shape, patches, xstep=12, ystep=12, zstep=12):
    """Reassemble 3-D patches into ``out_shape``, averaging where patches overlap.

    ``patches`` must hold the patches in row-major origin order, as produced by
    ``patch_extract_3D``.  Its shape is either ``(nx * ny * nz, pd, ph, pw)`` or
    ``(nx, ny, nz, pd, ph, pw)``, with ``n_k = (out_shape[k] - p_k + 1) // step_k``.

    NOTE: the step defaults here are 12, whereas ``patch_extract_3D`` defaults
    to 1.  Pass the steps that were used for extraction.

    Returns:
        Array of shape ``out_shape``.  Each voxel is the mean of all patch
        values covering it.  Voxels covered by no patch are 0 rather than NaN.
        The dtype is that of ``patches / count``: float64 for integer patches,
        otherwise ``patches.dtype``.

    Raises:
        ValueError: if the patch is larger than ``out_shape`` or the number of
            patches does not match the origin grid.
    """
    patches = np.asarray(patches)
    out = np.zeros(out_shape, patches.dtype)
    denom = np.zeros(out_shape, patches.dtype)
    patch_shape = patches.shape[-3:]
    steps = (xstep, ystep, zstep)
    grid = tuple((out.shape[k] - patch_shape[k] + 1) // steps[k] for k in range(3))
    if min(grid) < 0:
        raise ValueError("patch shape %r exceeds output shape %r" % (patch_shape, out.shape))
    patches_6D = patches.reshape(grid + patch_shape)
    # Accumulate one in-patch offset at a time.  A strided slice never selects
    # the same voxel twice, so plain `+=` sums correctly without np.add.at.  It
    # also avoids materialising full-size index arrays.
    for offset in np.ndindex(*patch_shape):
        target = tuple(slice(o, o + n * s, s) for o, n, s in zip(offset, grid, steps))
        out[target] += patches_6D[(Ellipsis,) + offset]
        denom[target] += 1
    # Uncovered voxels have denom == 0; silence the 0/0 warning and set them to 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        averaged = out / denom
    averaged[denom == 0] = 0
    return averaged
Upvotes: 5