user3051460

Reputation: 1485

How to recover 3D image from its patches in Python?

I have a 3D image of shape DxHxW. I have successfully split it into overlapping patches of shape pd x ph x pw, and I apply some processing to each patch. Now I would like to rebuild the image from the processed patches so that the result has the same shape as the original image. How can I do this?


This is my code to extract the patches:

import numpy as np

def patch_extract_3D(input, patch_shape, xstep=1, ystep=1, zstep=1):
    # Number of patches along each axis (// keeps the shape integral on Python 3)
    counts = ((input.shape[0] - patch_shape[0] + 1) // xstep,
              (input.shape[1] - patch_shape[1] + 1) // ystep,
              (input.shape[2] - patch_shape[2] + 1) // zstep)
    # 6D view: the first three axes select a patch, the last three index inside it
    patches_3D = np.lib.stride_tricks.as_strided(
        input, counts + tuple(patch_shape),
        (input.strides[0] * xstep, input.strides[1] * ystep, input.strides[2] * zstep) + input.strides)
    # reshape copies the overlapping view into an (n_patches, pd, ph, pw) array
    patches_3D = patches_3D.reshape(-1, patch_shape[0], patch_shape[1], patch_shape[2])
    return patches_3D

This is the processing of the patches (simply multiplying by 2):

for i in range(patches_3D.shape[0]):
    patches_3D[i] = patches_3D[i] * 2

Now I want to turn patches_3D back into an image with the original shape. Thanks.

Here is example code:

import numpy as np

patch_shape = [2, 2, 2]
input = np.arange(4*4*6).reshape(4, 4, 6)
patches_3D = patch_extract_3D(input, patch_shape)
print(patches_3D.shape)
for i in range(patches_3D.shape[0]):
    patches_3D[i] = patches_3D[i] * 2
print(patches_3D.shape)

Upvotes: 1

Views: 1551

Answers (1)

Paul Panzer

Reputation: 53079

This will do the reverse. However, since your patches overlap, the result is only well-defined if their values agree wherever they overlap:

import numpy as np

def stuff_patches_3D(out_shape, patches, xstep=12, ystep=12, zstep=12):
    # Pass the same steps that were used for extraction (note the defaults here are 12)
    out = np.zeros(out_shape, patches.dtype)
    patch_shape = patches.shape[-3:]
    # Same 6D strided view as in the extraction, but over the output array
    patches_6D = np.lib.stride_tricks.as_strided(
        out,
        ((out.shape[0] - patch_shape[0] + 1) // xstep,
         (out.shape[1] - patch_shape[1] + 1) // ystep,
         (out.shape[2] - patch_shape[2] + 1) // zstep) + tuple(patch_shape),
        (out.strides[0] * xstep, out.strides[1] * ystep, out.strides[2] * zstep) + out.strides)
    # Writing through the view fills out; where patches overlap, only one written value survives
    patches_6D[...] = patches.reshape(patches_6D.shape)
    return out
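To see what "well-defined" means in practice, here is a small loop-based sketch of the same write-back. It is not the strided code above: `write_patches` is a hypothetical helper, it assumes the patches are ordered in C order over a step-1 grid, and it uses `sliding_window_view` (NumPy 1.20 or later) only to produce test patches. When every patch is processed consistently the round trip is exact; when one patch disagrees with its neighbours, the result depends on which patch was written last.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def write_patches(out_shape, patches, step=1):
    """Hypothetical helper: write patches back by plain assignment.

    Assumes patches are ordered over the grid range(0, n - p + 1, step)
    on each axis, in C order. Later patches overwrite earlier ones.
    """
    out = np.zeros(out_shape, patches.dtype)
    pd, ph, pw = patches.shape[-3:]
    i = 0
    for z in range(0, out_shape[0] - pd + 1, step):
        for y in range(0, out_shape[1] - ph + 1, step):
            for x in range(0, out_shape[2] - pw + 1, step):
                out[z:z + pd, y:y + ph, x:x + pw] = patches[i]
                i += 1
    return out

image = np.arange(4 * 4 * 6).reshape(4, 4, 6)
# All step-1 patches of shape (2, 2, 2); multiplying creates an independent copy
patches = sliding_window_view(image, (2, 2, 2)).reshape(-1, 2, 2, 2) * 2

# Consistent processing: overlapping patches agree, so the round trip is exact
restored = write_patches(image.shape, patches)

# Inconsistent processing: only the first patch is shifted by 100
patches_bad = patches.copy()
patches_bad[0] += 100
restored_bad = write_patches(image.shape, patches_bad)
# Voxel (0, 0, 0) belongs only to patch 0, so it keeps the shift;
# voxel (1, 1, 1) is overwritten by later patches, so the shift is lost there
```

If your per-patch processing can make neighbouring patches disagree, plain assignment silently drops information, which is the motivation for averaging the overlaps.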

Update: here is a safer version that averages overlapping pixels:

import numpy as np

def stuff_patches_3D(out_shape, patches, xstep=12, ystep=12, zstep=12):
    out = np.zeros(out_shape, patches.dtype)
    denom = np.zeros(out_shape, patches.dtype)
    patch_shape = patches.shape[-3:]

    def view_6D(a):
        # 6D strided view over a: patch position on the first three axes, offset inside the patch on the last three
        return np.lib.stride_tricks.as_strided(
            a,
            ((a.shape[0] - patch_shape[0] + 1) // xstep,
             (a.shape[1] - patch_shape[1] + 1) // ystep,
             (a.shape[2] - patch_shape[2] + 1) // zstep) + tuple(patch_shape),
            (a.strides[0] * xstep, a.strides[1] * ystep, a.strides[2] * zstep) + a.strides)

    patches_6D = view_6D(out)
    denom_6D = view_6D(denom)
    # np.add.at is unbuffered, so a voxel shared by several patches accumulates every contribution
    idx = tuple(x.ravel() for x in np.indices(patches_6D.shape))
    np.add.at(patches_6D, idx, patches.ravel())
    np.add.at(denom_6D, idx, 1)
    # denom counts the patches covering each voxel; voxels covered by no patch divide by zero
    return out / denom
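For comparison, the same averaging idea can be written with explicit loops, which is slower but easier to check by hand. This is only a sketch under stated assumptions: `average_patches` is a hypothetical helper (not part of the answer's code), it assumes the patches are ordered in C order over the grid `range(0, n - p + 1, step)` on each axis, it accumulates in float, and it leaves voxels that no patch covers at zero instead of dividing by zero.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def average_patches(out_shape, patches, step=1):
    """Hypothetical helper: overlap-add the patches, then divide by the coverage count."""
    pd, ph, pw = patches.shape[-3:]
    total = np.zeros(out_shape, dtype=float)   # sum of all contributions per voxel
    count = np.zeros(out_shape, dtype=float)   # number of patches covering each voxel
    i = 0
    for z in range(0, out_shape[0] - pd + 1, step):
        for y in range(0, out_shape[1] - ph + 1, step):
            for x in range(0, out_shape[2] - pw + 1, step):
                total[z:z + pd, y:y + ph, x:x + pw] += patches[i]
                count[z:z + pd, y:y + ph, x:x + pw] += 1
                i += 1
    covered = count > 0
    total[covered] /= count[covered]           # uncovered voxels stay 0
    return total

image = np.arange(4 * 4 * 6).reshape(4, 4, 6)
# All step-1 patches of shape (2, 2, 2), doubled as in the question
patches = sliding_window_view(image, (2, 2, 2)).reshape(-1, 2, 2, 2) * 2.0

# Consistent patches: averaging reproduces the doubled image
avg = average_patches(image.shape, patches)

# Shift only the first patch by 8: the shift is diluted by the coverage count
patches_bad = patches.copy()
patches_bad[0] += 8
avg_bad = average_patches(image.shape, patches_bad)
# Voxel (0, 0, 0) is covered by 1 patch, so it moves by 8;
# voxel (1, 1, 1) is covered by 8 patches, so it moves by 1
```

The loops cost one Python iteration per patch, so for large volumes the `np.add.at` version or a vectorized accumulation is likely preferable; the loop form is mainly useful for verifying the result.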

Upvotes: 5
