Karol Szymczak

Reputation: 64

Why does torch.nn.Conv2d() divide the image into 9 parts?

Sorry for the basic question, but why does torch.nn.Conv2d() divide the image into 9 parts?

import torch
from torch import nn
import cv2

img = cv2.imread("image_game/eldenring 2022-12-14 19-29-50.png")
cv2.imshow('input', img)
size = img.shape #  (720, 1280, 3)
img = img.reshape((1, img.shape[2], size[0], size[1]))
img = torch.tensor(img, dtype=torch.float32)  #  torch.Size([1, 3, 720, 1280])

c1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=2, stride=1)
img = c1(img)

size = img.shape
img = img.reshape((size[2], size[3], size[1])).detach().numpy()
cv2.imshow('output', img)
cv2.waitKey(0)

This returns:

input image: [input]

output image: [output]

I want this:

gif


Edit:

When I use

c1 = nn.Conv2d(1, 1, kernel_size=(3, 3), padding=2, stride=1)

instead of

c1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=2, stride=1)

I get what I want, but how do I do it when there are more channels?

Upvotes: 0

Views: 183

Answers (1)

Karol Szymczak

Reputation: 64

I'm sorry that the description of the question was unclear. Javier TG solved my problem in a comment:

The issue is with using reshape to permute the axes. OpenCV's imread gives an array of shape (H, W, 3), so to get PyTorch's (1, 3, H, W) representation, transpose (in NumPy) and permute (in PyTorch) should be used instead. Try substituting the first reshape with img = img[None].transpose(0, 3, 1, 2), and the last reshape with img = img[0].permute(1, 2, 0).detach().numpy() – Javier TG

I thought the problem was in the nn.Conv2d() function, but I had simply transposed the data incorrectly.
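To see why reshape and transpose are not interchangeable here, a minimal NumPy sketch may help. It uses a tiny made-up array (2 x 3 pixels, 3 channels) in place of the real screenshot; the sizes and values are assumptions for illustration only.

```python
import numpy as np

# A tiny stand-in for an OpenCV image: H=2, W=3, 3 channels (hypothetical values).
H, W, C = 2, 3, 3
img = np.arange(H * W * C).reshape(H, W, C)

# reshape keeps the flat memory order and only relabels the axes,
# so values from different channels end up mixed inside each "channel" plane.
wrong = img.reshape(1, C, H, W)

# transpose actually moves the axes, so plane c holds channel c of every pixel.
right = img[None].transpose(0, 3, 1, 2)

print(np.array_equal(right[0, 0], img[:, :, 0]))  # True
print(np.array_equal(wrong[0, 0], img[:, :, 0]))  # False
```

Both results have the shape (1, 3, 2, 3), which is why the mistake does not raise an error: only the placement of the values differs.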

Corrected code:

import torch
from torch import nn
import cv2

img = cv2.imread("image_game/eldenring 2022-12-14 19-29-50.png")
cv2.imshow('input', img)  # (720, 1280, 3)
img = img[None].transpose(0, 3, 1, 2)
img = torch.as_tensor(img).float()  # torch.Size([1, 3, 720, 1280])

c1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=1, stride=1)
img = c1(img)

img = img[0].permute(1, 2, 0).detach().numpy()  # (720, 1280, 3)
cv2.imshow('output', img)
cv2.waitKey(0)
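Note that the corrected code also uses padding=1 where the question used padding=2. As a rough check, the output-size formula given in the PyTorch Conv2d documentation can be sketched in plain Python; the helper name conv2d_out_size is hypothetical and not part of PyTorch.

```python
def conv2d_out_size(size, kernel, padding=0, stride=1, dilation=1):
    """Output length along one spatial axis, per the Conv2d documentation formula.

    Hypothetical helper for illustration; it is not a PyTorch function.
    """
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# 3x3 kernel, stride 1, on a 720 x 1280 image
print(conv2d_out_size(720, 3, padding=2), conv2d_out_size(1280, 3, padding=2))  # 722 1282
print(conv2d_out_size(720, 3, padding=1), conv2d_out_size(1280, 3, padding=1))  # 720 1280
```

With a 3x3 kernel and stride 1, padding=1 keeps the 720 x 1280 size, while padding=2 grows each side by 2 pixels.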

Upvotes: 1
