Atheel Massalha

Reputation: 464

Conv2d with out_channels=2 producing output with 1 channel

As I understand it, out_channels should determine the number of channels in the output (I am new to PyTorch). I am running this code:

import torch
from torch import nn

img = torch.tensor([[[[1,1,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1]],[[1,1,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1]]]])
torchDetector = nn.Conv2d(in_channels=2,out_channels=2, kernel_size=(2, 2),bias=False, stride=1, padding=0)
filter = torch.tensor([[[[1,1],[1,1]],[[-1,-1],[-1,-1]]]])
torchDetector.weight = nn.Parameter(filter, requires_grad=False)
torchDetected = torchDetector(img)

print(torchDetected)

I expected a result with 2 channels, but I get only 1 channel:

tensor([[[[0, 0, 0],
          [0, 0, 0],
          [0, 0, 0]]]])

I was expecting this result:

tensor([[[[ 8.,  8.,  8.],
          [ 8.,  8.,  8.],
          [ 8.,  8.,  8.]],

         [[-8., -8., -8.],
          [-8., -8., -8.],
          [-8., -8., -8.]]]])

What am I missing?

Upvotes: 1

Views: 356

Answers (1)

Alaa M.

Reputation: 5273

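You need to add an output channel to the filter. The weight of a Conv2d layer has shape (out_channels, in_channels, kernel_height, kernel_width). Your filter has shape (1, 2, 2, 2), so it defines a single output channel that sums over both input channels, which is why every value is 4 - 4 = 0. With one filter per output channel, the weight has shape (2, 2, 2, 2):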

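import torch
from torch import nn

# Input of shape (batch=1, channels=2, height=4, width=4), all ones.
img = torch.tensor([[[[1,1,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1]],[[1,1,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1]]]])
print(img.shape)
print(img)
print()
torchDetector = nn.Conv2d(2, 2, kernel_size=(2, 2), bias=False, stride=1, padding=0)
# Weight of shape (out_channels=2, in_channels=2, 2, 2):
# one 2-channel filter per output channel.
filter = torch.tensor([
                       [[[1, 1], [1, 1]],        # output channel 0, input channel 0
                        [[1, 1], [1, 1]]],       # output channel 0, input channel 1
                       [[[-1, -1], [-1, -1]],    # output channel 1, input channel 0
                        [[-1, -1], [-1, -1]]]    # output channel 1, input channel 1
                       ])
torchDetector.weight = nn.Parameter(filter, requires_grad=False)
torchDetected = torchDetector(img)

print(torchDetected.shape)
print(torchDetected)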

Output:

torch.Size([1, 2, 4, 4])
tensor([[[[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]],

         [[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]]]])

torch.Size([1, 2, 3, 3])
tensor([[[[ 8,  8,  8],
          [ 8,  8,  8],
          [ 8,  8,  8]],

         [[-8, -8, -8],
          [-8, -8, -8],
          [-8, -8, -8]]]])
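
To see where the single channel comes from, here is a minimal sketch that compares the two weight shapes. It uses float tensors, which is an assumption on my part (the code above uses integers). It also uses set_conv_weight, a hypothetical helper written for this example and not part of PyTorch, which refuses a weight whose shape does not match the layer instead of silently replacing it:

```python
import torch
import torch.nn.functional as F
from torch import nn

# Assumption: float input, all ones, shape (batch=1, channels=2, 4, 4).
img = torch.ones(1, 2, 4, 4)

# The question's filter: shape (1, 2, 2, 2), i.e. ONE output channel
# spanning both input channels.
single = torch.tensor([[[[1., 1.], [1., 1.]],
                        [[-1., -1.], [-1., -1.]]]])

# The corrected filter: shape (2, 2, 2, 2), i.e. TWO output channels.
double = torch.stack([torch.ones(2, 2, 2), -torch.ones(2, 2, 2)])

# The functional form shows that the output channel count follows
# weight.shape[0], not the out_channels passed to the constructor.
single_out = F.conv2d(img, single)
print(single_out.shape)  # one output channel, every value is 4 - 4 = 0


def set_conv_weight(conv, weight):
    """Hypothetical helper: copy `weight` into `conv` only if shapes match."""
    if weight.shape != conv.weight.shape:
        raise ValueError(
            f"expected {tuple(conv.weight.shape)}, got {tuple(weight.shape)}"
        )
    with torch.no_grad():
        conv.weight.copy_(weight)


conv = nn.Conv2d(2, 2, kernel_size=2, bias=False)

shape_error = None
try:
    set_conv_weight(conv, single)  # rejected: (1, 2, 2, 2) != (2, 2, 2, 2)
except ValueError as err:
    shape_error = str(err)
    print(shape_error)

set_conv_weight(conv, double)
out = conv(img)
print(out.shape)
print(out)
```

Assigning a new nn.Parameter to torchDetector.weight, as in the question, skips any such shape check, so a filter with the wrong number of output channels is accepted without an error.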

Upvotes: 1
