Reputation: 64
Sorry for the stupid question, but why does torch.nn.Conv2d() divide the image into 9 parts?
import torch
from torch import nn
import cv2
img = cv2.imread("image_game/eldenring 2022-12-14 19-29-50.png")
cv2.imshow('input', img)
size = img.shape # (720, 1280, 3)
img = img.reshape((1, img.shape[2], size[0], size[1]))
img = torch.tensor(img, dtype=torch.float32) # torch.Size([1, 3, 720, 1280])
c1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=2, stride=1)
img = c1(img)
size = img.shape
img = img.reshape((size[2], size[3], size[1])).detach().numpy()
cv2.imshow('output', img)
cv2.waitKey(0)
It returns this (output screenshot omitted: the image split into 9 parts):

I want this (desired screenshot omitted):
Edit:
When I use
c1 = nn.Conv2d(1, 1, kernel_size=(3, 3), padding=2, stride=1)
instead of
c1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=2, stride=1)
I get what I want, but how do I do it when there are more channels?
Upvotes: 0
Views: 183
Reputation: 64
I'm sorry the description of the question was unclear. Javier TG solved my problem:
The issue is with using reshape to permute the axes -> opencv's imread gives an array of size (H, W, 3), so to get the pytorch's (1, 3, H, W) representation, transpose (in numpy) and permute (in pytorch) should be used instead. Try substituting the first reshape with img = img[None].transpose(0, 3, 1, 2), and the last reshaping with img = img[0].permute(1, 2, 0).detach().numpy() – Javier TG
I thought the problem was in nn.Conv2d(), but I had just transposed the data incorrectly with reshape.
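A tiny standalone example (my own, not from the comment) shows why: reshape only re-slices the flat memory buffer, while transpose actually moves the axes, so reshape ends up building "channel" planes out of mixed-together pixel values.

```python
import numpy as np

# A 2x2 "image" with 3 channels; every pixel holds one repeated value,
# so a correct channel plane must look like [[1, 2], [3, 4]].
img = np.array([[[1, 1, 1], [2, 2, 2]],
                [[3, 3, 3], [4, 4, 4]]])  # shape (H, W, C) = (2, 2, 3)

# Wrong: reshape keeps the flat memory order (1,1,1,2,2,2,3,3,3,4,4,4),
# so "channel 0" is just the first 4 values of that buffer.
wrong = img.reshape(3, 2, 2)
print(wrong[0])  # [[1 1] [1 2]] -- mixed pixel values, not a channel

# Right: transpose moves the channel axis to the front,
# so each plane is one true channel of the image.
right = img.transpose(2, 0, 1)
print(right[0])  # [[1 2] [3 4]] -- channel 0 for every pixel
```

The same distinction holds in PyTorch: Tensor.reshape has the buffer-reordering behavior, Tensor.permute the axis-moving one.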
Corrected code:
import torch
from torch import nn
import cv2
img = cv2.imread("image_game/eldenring 2022-12-14 19-29-50.png")
cv2.imshow('input', img) # (720, 1280, 3)
img = img[None].transpose(0, 3, 1, 2)
img = torch.as_tensor(img).float() # torch.Size([1, 3, 720, 1280])
c1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=1, stride=1)
img = c1(img)
img = img[0].permute(1, 2, 0).detach().numpy() # (720, 1280, 3)
cv2.imshow('output', img)
cv2.waitKey(0)
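As a sanity check on the corrected pipeline (with a random dummy array standing in for the screenshot, so the shapes are hypothetical), the transpose/permute round trip can be verified without OpenCV:

```python
import numpy as np
import torch
from torch import nn

# Dummy (H, W, 3) "image" standing in for the cv2.imread result.
img = np.random.rand(8, 16, 3).astype(np.float32)

# (H, W, C) -> (1, C, H, W): add a batch axis, then move channels forward.
x = torch.as_tensor(img[None].transpose(0, 3, 1, 2))
assert x.shape == (1, 3, 8, 16)

# padding=1 with a 3x3 kernel and stride 1 preserves the spatial size.
out = nn.Conv2d(3, 3, kernel_size=3, padding=1, stride=1)(x)

# (1, C, H, W) -> (H, W, C): drop the batch axis, move channels back.
back = out[0].permute(1, 2, 0).detach().numpy()
assert back.shape == (8, 16, 3)
```

Note that padding=1 (not the original padding=2) is what keeps the output the same height and width as the input for a 3x3 kernel.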
Upvotes: 1