godot

Reputation: 1570

conv2d function in pytorch

I'm trying to use the function torch.conv2d from Pytorch but can't get a result I understand...

Here is a simple example where the kernel (filt) is the same size as the input (im) to explain what I'm looking for.

import torch

filt = torch.rand(3, 3)
im = torch.rand(3, 3)

I want to compute a simple convolution with no padding, so the result should be a scalar (i.e. a 1x1 tensor).
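Here is a small sketch of the "valid" output-size rule, using scipy.signal.convolve2d as in the question. The helper valid_output_shape is hypothetical and only for illustration. With no padding, the kernel is only placed where it fits entirely inside the input, so an H x W input and a kH x kW kernel give an (H - kH + 1) x (W - kW + 1) output. A 3x3 kernel on a 3x3 input therefore leaves exactly one position.

```python
import numpy as np
import scipy.signal

def valid_output_shape(in_shape, k_shape):
    # Hypothetical helper: in "valid" mode the kernel must fit entirely
    # inside the input, giving (H - kH + 1) x (W - kW + 1) positions.
    return (in_shape[0] - k_shape[0] + 1, in_shape[1] - k_shape[1] + 1)

rng = np.random.default_rng(0)
shapes = {}
for in_shape, k_shape in [((3, 3), (3, 3)), ((5, 4), (3, 3))]:
    im = rng.random(in_shape)
    filt = rng.random(k_shape)
    out = scipy.signal.convolve2d(im, filt, mode="valid")
    shapes[(in_shape, k_shape)] = out.shape
    print(in_shape, k_shape, "->", out.shape, valid_output_shape(in_shape, k_shape))
```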

I tried this with conv2d:

# conv2d needs 4-dimensional tensors, so I reshape the image and kernel
im_height, filt_height = im.shape[0], filt.shape[0]  # both are 3 here
im_torch = im.reshape((im_height, filt_height, 1, 1))
filt_torch = filt.reshape((filt_height, im_height, 1, 1))
out = torch.nn.functional.conv2d(im_torch, filt_torch, stride=1, padding=0)
print(out)

But the result is not what I expected:

tensor([[[[0.6067]], [[0.3564]], [[0.5397]]],
    [[[0.2557]], [[0.0493]], [[0.2562]]],
    [[[0.6067]], [[0.3564]], [[0.5397]]]])
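A sketch of why the output has this shape, assuming im_height and filt_height are both 3 as in the example above. conv2d reads the last two dimensions as height and width, so the (3, 3, 1, 1) reshape describes a batch of 3 images with 3 channels of size 1x1, filtered by 3 output channels of 1x1 kernels. That is a 1x1 convolution, which amounts to a matrix product over the channel axis rather than a single 3x3 window.

```python
import torch
import torch.nn.functional as F

im = torch.rand(3, 3)
filt = torch.rand(3, 3)

# Interpreted by conv2d as:
#   input:  batch=3, in_channels=3, height=1, width=1
#   weight: out_channels=3, in_channels=3, kH=1, kW=1
im_torch = im.reshape(3, 3, 1, 1)
filt_torch = filt.reshape(3, 3, 1, 1)
out = F.conv2d(im_torch, filt_torch)
print(out.shape)  # torch.Size([3, 3, 1, 1])

# Each entry out[n, o, 0, 0] = sum over c of im[n, c] * filt[o, c]
assert torch.allclose(out.reshape(3, 3), im @ filt.T, atol=1e-6)
```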

To give an idea of what I'd like, I want to reproduce the behavior of scipy's convolve2d:

import scipy.signal
out_scipy = scipy.signal.convolve2d(im.detach().numpy(), filt.detach().numpy(), 'valid')
print(out_scipy)

which prints:

array([[1.195723]], dtype=float32)
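For reference, a sketch of what that single scipy value is. scipy.signal.convolve2d computes a true convolution, which flips the kernel in both axes before sliding it. With an input and kernel of the same size, "valid" mode leaves one position, so the result is the sum of the input times the flipped kernel. The random values here are illustrative only, so no specific number is shown.

```python
import numpy as np
import scipy.signal

rng = np.random.default_rng(0)
im = rng.random((3, 3)).astype(np.float32)
filt = rng.random((3, 3)).astype(np.float32)

out_scipy = scipy.signal.convolve2d(im, filt, mode="valid")

# True convolution: flip the kernel along both axes, multiply, sum
manual = (im * filt[::-1, ::-1]).sum()

print(out_scipy.shape)  # (1, 1)
assert np.isclose(out_scipy[0, 0], manual, atol=1e-6)
```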

Upvotes: 5

Views: 8238

Answers (2)

Krueger

Reputation: 1228

The tensor shapes of your input and filter should be:

(batch, channels, height, width) for the input, and (out_channels, in_channels, height, width) for the filter

and NOT:

(width, height, 1, 1)

e.g.

import torch
import torch.nn.functional as F
x = torch.randn(1, 1, 4, 4)
y = torch.randn(1, 1, 4, 4)
z = F.conv2d(x, y)

Output shape of z:

torch.Size([1, 1, 1, 1])
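A minimal sketch applying this layout to the question's own 3x3 im and filt (random values, so no fixed number is shown). Each tensor gets a leading batch-or-out-channel dimension of 1 and a channel dimension of 1, and conv2d then returns a single value. Note that conv2d does not flip the kernel, so that value is the element-wise product of im and filt, summed.

```python
import torch
import torch.nn.functional as F

im = torch.rand(3, 3)
filt = torch.rand(3, 3)

# input  -> (batch=1, channels=1, height=3, width=3)
# weight -> (out_channels=1, in_channels=1, kH=3, kW=3)
im_torch = im.reshape(1, 1, 3, 3)
filt_torch = filt.reshape(1, 1, 3, 3)

out = F.conv2d(im_torch, filt_torch, stride=1, padding=0)
print(out.shape)  # torch.Size([1, 1, 1, 1])

# No kernel flip: the single value is sum(im * filt)
assert torch.allclose(out.squeeze(), (im * filt).sum(), atol=1e-6)
```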

Upvotes: 8

godot

Reputation: 1570

Ok, I didn't find the exact answer to my question (i.e. how to use conv2d) but I found another way to do it.

First of all, I learned that what I'm looking for is called a valid cross-correlation, and that it is actually the operation implemented by the Conv2d class.

Hence my solution uses the Conv2d class instead of the conv2d function.
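A sketch of the cross-correlation versus convolution distinction, assuming scipy is available for comparison. The functional torch.nn.functional.conv2d also computes a valid cross-correlation (no kernel flip), so it matches scipy.signal.correlate2d directly. Flipping the kernel along both axes with torch.flip turns it into a true convolution, which matches scipy.signal.convolve2d from the question. The helper torch_valid is hypothetical and only wraps the reshaping.

```python
import numpy as np
import scipy.signal
import torch
import torch.nn.functional as F

im = torch.rand(3, 3)
filt = torch.rand(3, 3)

def torch_valid(x, k):
    # Hypothetical helper: add batch/channel dims, run F.conv2d
    # (a valid cross-correlation), and drop the extra dims again.
    return F.conv2d(x.reshape(1, 1, *x.shape), k.reshape(1, 1, *k.shape))[0, 0]

im_np, filt_np = im.numpy(), filt.numpy()

# Cross-correlation: matches scipy.signal.correlate2d
corr = torch_valid(im, filt)
assert np.allclose(corr.numpy(), scipy.signal.correlate2d(im_np, filt_np, "valid"), atol=1e-6)

# Flip the kernel in both axes to get a true convolution: matches convolve2d
conv = torch_valid(im, torch.flip(filt, dims=(0, 1)))
assert np.allclose(conv.numpy(), scipy.signal.convolve2d(im_np, filt_np, "valid"), atol=1e-6)
```

For a symmetric kernel the two operations coincide; for a general kernel such as the random one in the question, they differ unless the kernel is flipped.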

import torch

img = torch.rand(3, 3)

model = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3), stride=1, padding=0, bias=False)

# Conv2d expects a (batch, channels, height, width) input
res = model(img.reshape(1, 1, 3, 3))
print(res.shape)

Which prints the shape of the single-value (1x1) output I wanted:

torch.Size([1, 1, 1, 1])

PS: I also checked that the result is the right one, not just the dimension.
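One caveat worth noting: a freshly created Conv2d layer initialises its weight randomly, so to reproduce a specific kernel it has to be copied in. A minimal sketch, assuming filt is the 3x3 kernel from the question (here random for illustration); model.weight has shape (out_channels, in_channels, kH, kW), and the result is compared against the functional form.

```python
import torch
import torch.nn.functional as F

img = torch.rand(3, 3)
filt = torch.rand(3, 3)  # stands in for the kernel you want the layer to use

model = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3),
                        stride=1, padding=0, bias=False)

# Overwrite the random initial weight with the desired kernel.
# model.weight has shape (1, 1, 3, 3).
with torch.no_grad():
    model.weight.copy_(filt.reshape(1, 1, 3, 3))

res = model(img.reshape(1, 1, 3, 3))

# The module and the functional form give the same single value
expected = F.conv2d(img.reshape(1, 1, 3, 3), filt.reshape(1, 1, 3, 3))
assert torch.allclose(res, expected)
```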

Upvotes: 2
