Reputation: 335
Given a batch image tensor like B x C x W x H (batchSize,channels,width,height),
I would like to create a new tensor in which the new channels are the channels from nearby pixels (padded with 0s).
For instance, if I choose a 3 x 3 neighborhood (like a 3 x 3 filter), there are 9 nearby pixels in total and the final tensor size would be B x (9 * C) x W x H.
Any recommendations on how to do this, or do I just need to take the brute-force approach and iterate?
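For reference, the output described above can be built without iterating over individual pixels: a loop over the 9 window offsets, each taken as a shifted slice of a zero-padded copy, is enough. This is a minimal sketch that assumes an odd window size; the function name `neighbor_channels_loop` is hypothetical. It treats the last two dims symmetrically, so it behaves the same for B x C x W x H or B x C x H x W input. Channel `c * k * k + i` holds neighbor `i` (row-major within the window) of input channel `c`.

```python
import torch
import torch.nn.functional as F


def neighbor_channels_loop(x: torch.Tensor, k: int = 3) -> torch.Tensor:
    """(B, C, H, W) -> (B, C*k*k, H, W); assumes odd k, zero padding at borders."""
    B, C, H, W = x.shape  # the two spatial dims are handled symmetrically
    p = k // 2
    # Zero-pad both spatial dims by p pixels on every side
    xp = F.pad(x, (p, p, p, p))
    shifted = []
    for dy in range(k):
        for dx in range(k):
            # Each slice is the input shifted by (dy - p, dx - p)
            shifted.append(xp[:, :, dy:dy + H, dx:dx + W])
    # Stack to (B, C, k*k, H, W) so each channel's neighbors stay together
    out = torch.stack(shifted, dim=2)
    return out.reshape(B, C * k * k, H, W)


x = torch.arange(9.).reshape(1, 1, 3, 3)
out = neighbor_channels_loop(x)
print(out.shape)         # torch.Size([1, 9, 3, 3])
print(out[0, :, 1, 1])   # the centre pixel's 3x3 neighborhood: 0, 1, ..., 8
```

Only 9 small tensor copies are made, so it stays on the GPU and remains differentiable.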
Upvotes: 1
Views: 310
Reputation: 1
I wanted to do the same thing, and I think it can be done with a single conv2d operation using one-hot kernels:
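import torch
import torch.nn.functional as F

# tensor_to_repeat: input of shape (B, C, H, W)
window_size = 3
C = tensor_to_repeat.shape[1]
# Prepare the kernel: one one-hot filter per window position, repeated for each channel
windowmize_kernel = torch.zeros(window_size ** 2, 1, window_size, window_size, dtype=tensor_to_repeat.dtype, device=tensor_to_repeat.device)
for i in range(window_size ** 2):
    windowmize_kernel[i, 0, i // window_size, i % window_size] = 1
windowmize_kernel = windowmize_kernel.repeat(C, 1, 1, 1)  # shape (C * window_size**2, 1, k, k)
# Apply the kernel; groups=C keeps the channels separate, zero padding keeps H x W
repeated_tensor = F.conv2d(tensor_to_repeat, windowmize_kernel, stride=1, padding="same", groups=C)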
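Why the kernel layout matters (an illustrative sketch, not part of the answer): a kernel of shape `(9, C, k, k)` used without `groups` sums the shifted copies over all input channels, so it yields B x 9 x H x W rather than B x 9C x H x W. A depthwise kernel of shape `(9C, 1, k, k)` with `groups=C` keeps the channels apart. The names `one_hot`, `dense` and `depthwise` are only for illustration; the input has two channels holding 0..8 and 9..17.

```python
import torch
import torch.nn.functional as F

x = torch.arange(18.).reshape(1, 2, 3, 3)  # channel 0: 0..8, channel 1: 9..17
k = 3
one_hot = torch.zeros(k * k, 1, k, k)
for i in range(k * k):
    one_hot[i, 0, i // k, i % k] = 1  # filter i picks window position i

# Dense kernel (9, C, k, k): every output channel sums over the input channels
dense = F.conv2d(x, one_hot.repeat(1, 2, 1, 1), padding=k // 2)
# Depthwise kernel (C*9, 1, k, k) with groups=C: channels stay separate
depthwise = F.conv2d(x, one_hot.repeat(2, 1, 1, 1), padding=k // 2, groups=2)

print(dense.shape, depthwise.shape)  # 9 vs 18 output channels
print(dense[0, :, 1, 1])      # centre pixel: 9, 11, ..., 25 (channels added together)
print(depthwise[0, :, 1, 1])  # centre pixel: 0, 1, ..., 17 (channels kept apart)
```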
Upvotes: 0
Reputation: 4826
For future readers: if you don't want to break the computation graph (which happens when you convert to NumPy for skimage), or want a more efficient implementation that avoids moving data to and from the GPU, you probably want a native PyTorch solution instead.
This problem is very close to an inverse PixelShuffle, which has an open feature request. The difference is that the original poster wants to keep the image resolution, whereas this operation reduces the height and width by a factor of r.
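A native option that does keep the resolution (not part of this answer) is `torch.nn.functional.unfold` with stride 1 and padding `k // 2`: it returns the flattened k x k patches as a `(B, C*k*k, L)` tensor, and with that padding `L` equals the number of pixels, so a reshape restores the spatial layout. A minimal sketch assuming an odd window size; the function name `neighbor_channels_unfold` is hypothetical.

```python
import torch
import torch.nn.functional as F


def neighbor_channels_unfold(x: torch.Tensor, k: int = 3) -> torch.Tensor:
    """(B, C, H, W) -> (B, C*k*k, H, W), zero-padded; assumes odd k."""
    B, C, H, W = x.shape
    # (B, C*k*k, H*W): channel c*k*k + i is window position i of input channel c
    cols = F.unfold(x, kernel_size=k, padding=k // 2)
    # The last dim enumerates output pixels row by row, so it reshapes back to (H, W)
    return cols.reshape(B, C * k * k, H, W)


x = torch.arange(9.).reshape(1, 1, 3, 3)
out = neighbor_channels_unfold(x)
print(out[0, :, 1, 1])  # the centre pixel's 3x3 neighborhood: 0, 1, ..., 8
```

It stays on the device of `x` and is differentiable, so it avoids both problems mentioned above.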
I am copying the requester's initial code (which is pretty efficient) here:
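import torch

# fm: feature map of shape (b, c, h, w); r: downscale factor (h and w divisible by r)
b, c, h, w = fm.shape
out_channel = c * (r ** 2)
out_h = h // r
out_w = w // r
# Split each spatial dim into (out, r) blocks, then move the r x r offsets into the channel dim
fm_view = fm.contiguous().view(b, c, out_h, r, out_w, r)
fm_prime = fm_view.permute(0, 1, 3, 5, 2, 4).contiguous().view(b, out_channel, out_h, out_w)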
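As far as I know, this operation has since shipped as `torch.nn.functional.pixel_unshuffle` (and the `torch.nn.PixelUnshuffle` module) starting with PyTorch 1.8. The check below wraps the requester's steps in a function (the name `requester_unshuffle` is hypothetical) and compares it with the built-in on a random tensor whose height and width are divisible by `r`.

```python
import torch
import torch.nn.functional as F


def requester_unshuffle(fm: torch.Tensor, r: int) -> torch.Tensor:
    # Same reshaping steps as the requester's code, with b, c, h, w read from the input
    b, c, h, w = fm.shape
    out_channel = c * (r ** 2)
    out_h, out_w = h // r, w // r
    fm_view = fm.contiguous().view(b, c, out_h, r, out_w, r)
    return fm_view.permute(0, 1, 3, 5, 2, 4).contiguous().view(b, out_channel, out_h, out_w)


fm = torch.randn(2, 3, 8, 6)
same = torch.equal(requester_unshuffle(fm, 2), F.pixel_unshuffle(fm, 2))
print(same)  # True
```

Both only rearrange elements, so the results are bitwise identical; either way the output is (b, c*r*r, h/r, w/r), which is why it does not answer the original question directly.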
Upvotes: 0
Reputation: 12407
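If you are fine with cutting off the border pixels (img is your image tensor, as a NumPy array, since view_as_windows only accepts ndarrays):
from skimage.util import view_as_windows
B, C, W, H = img.shape
# view_as_windows gives (B, C, W-2, H-2, 1, 1, 3, 3); flatten each 3x3 window, move it next to C, then merge
img_ = view_as_windows(img, (1, 1, 3, 3)).reshape(B, C, W - 2, H - 2, -1).transpose(0, 1, 4, 2, 3).reshape(B, C * 9, W - 2, H - 2)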
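And if you want to pad the borders with zeros instead:
import numpy as np
from skimage.util import view_as_windows
# Zero-pad one pixel on each side of the two spatial dims so the output keeps the original size
img = np.pad(img, ((0, 0), (0, 0), (1, 1), (1, 1)))
B, C, W, H = img.shape  # padded shape
img_ = view_as_windows(img, (1, 1, 3, 3)).reshape(B, C, W - 2, H - 2, -1).transpose(0, 1, 4, 2, 3).reshape(B, C * 9, W - 2, H - 2)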
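If you would rather not depend on scikit-image, NumPy 1.20+ provides `numpy.lib.stride_tricks.sliding_window_view`, which can window just the two spatial axes. Below is a sketch with a hypothetical `neighbor_channels_np` helper that covers both the cropped and the zero-padded variants (odd `k` assumed). Like the skimage version, it works on NumPy arrays, so it still leaves the PyTorch graph.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def neighbor_channels_np(img: np.ndarray, k: int = 3, pad: bool = True) -> np.ndarray:
    """(B, C, W, H) -> (B, C*k*k, W', H'); W', H' equal W, H when pad=True."""
    B, C = img.shape[:2]
    if pad:
        p = k // 2
        # Zero-pad only the two spatial axes
        img = np.pad(img, ((0, 0), (0, 0), (p, p), (p, p)))
    # Windows over axes 2 and 3 only: shape (B, C, W', H', k, k)
    win = sliding_window_view(img, (k, k), axis=(2, 3))
    Wo, Ho = win.shape[2], win.shape[3]
    # Flatten each k x k window, move it next to C, then merge (C, k*k) into channels
    return win.reshape(B, C, Wo, Ho, k * k).transpose(0, 1, 4, 2, 3).reshape(B, C * k * k, Wo, Ho)


img = np.arange(9.).reshape(1, 1, 3, 3)
print(neighbor_channels_np(img).shape)             # (1, 9, 3, 3)
print(neighbor_channels_np(img, pad=False).shape)  # (1, 9, 1, 1)
```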
Upvotes: 1