spadel

Reputation: 1036

How to use k channels in CNN for k FC Layers

I have an encoder that outputs a tensor of shape (bn, c * k, 32, 32). From this I now want to produce k means of shape (bn, k, 1, 2), i.e. each mean is a 2-dim coordinate. To do so, I want to use k FC layers, where each mean k_i only uses its own c channels.

So my idea is to reshape the encoder output out to a 5d tensor of shape (bn, k, c, 32, 32). Then I can use the flattened out[:, 0] ... out[:, k] as inputs for the k linear layers.

The trivial solution would be to define the linear layers manually:

self.fc0 = nn.Linear(c * 32 * 32, 2)
...
self.fck = nn.Linear(c * 32 * 32, 2)

Then I could define the forward pass for each mean as follows:

mean_0 = self.fc0(out[:, 0].reshape(bn, -1))
...
mean_k = self.fck(out[:, k].reshape(bn, -1))

Is there a more efficient way to do that?

Upvotes: 2

Views: 111

Answers (2)

Ivan

Reputation: 40648

I believe you are looking for a grouped convolution. Keep the k*c channels on axis=1, so the input shape is (bn, k*c, 32, 32). Then use an nn.Conv2d layer with 2*k filters and groups=k, so the convolution is not fully connected channel-wise: each of the k groups of c maps is convolved separately, c channels at a time.

>>> import torch
>>> import torch.nn as nn
>>> bn = 1; k = 5; c = 3
>>> x = torch.rand(bn, k*c, 32, 32)
>>> m = nn.Conv2d(in_channels=c*k, out_channels=2*k, kernel_size=32, groups=k)

>>> m(x).shape
torch.Size([1, 10, 1, 1])

Which you can then reshape to your liking.
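For instance, assuming the target shape (bn, k, 1, 2) from the question, that reshape could look like this (a sketch, with an arbitrary batch size):

```python
import torch
import torch.nn as nn

bn, k, c = 4, 5, 3
x = torch.rand(bn, k * c, 32, 32)

# groups=k: each group of c input channels feeds 2 of the 2*k output channels
m = nn.Conv2d(in_channels=c * k, out_channels=2 * k, kernel_size=32, groups=k)

means = m(x).reshape(bn, k, 1, 2)  # (bn, k, 1, 2): k 2-dim coordinates per sample
```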


In terms of the number of parameters: a typical nn.Conv2d usage would be:

>>> m = nn.Conv2d(in_channels=c*k, out_channels=2*k, kernel_size=32)
>>> sum(p.numel() for p in m.parameters())
153610

Which is exactly c*k*2*k*32*32 weights, plus 2*k biases.

In your case, you would have

>>> m = nn.Conv2d(in_channels=c*k, out_channels=2*k, kernel_size=32, groups=k)
>>> sum(p.numel() for p in m.parameters())
30730

Which is exactly c*2*k*32*32 weights, plus 2*k biases, i.e. k times fewer weights than the previous layer. A given filter has only c channels (not k*c), which means it convolves an input with c channels (i.e. one of the k groups containing c maps).
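As a quick sanity check (not part of the original answer), the weight tensor of the grouped layer shows directly that each filter spans only c input channels:

```python
import torch.nn as nn

k, c = 5, 3
m = nn.Conv2d(in_channels=c * k, out_channels=2 * k, kernel_size=32, groups=k)

# nn.Conv2d stores weights as (out_channels, in_channels // groups, kH, kW)
print(tuple(m.weight.shape))  # (10, 3, 32, 32): 2*k filters, each with only c channels
print(tuple(m.bias.shape))    # (10,): one bias per output channel
```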

Upvotes: 2

Girish Hegde

Reputation: 1515

You can do something like this using nn.ModuleList:

import torch
import torch.nn as nn

class fclist(nn.Module):
    def __init__(self, k):
        super().__init__()
        '''
        k: no. of clusters
        '''
        self.k = k
        '''
        .
        .
        .
        Other previous layers
        .
        .
        '''

        c = 1

        self.out_layers = nn.ModuleList()
        for i in range(k):
            self.out_layers.append(nn.Linear(c*32*32, 2))

    def forward(self, x):
        '''
        .
        .
        .
        pass through previous layers
        .
        .
        '''
        x = [layer(x) for layer in self.out_layers]
        return x

Sample output:

>>> net = fclist(k=3)
>>> inp = torch.randn(1, 1*32*32)
>>> net(inp)
[tensor([[-0.7319, -0.2686]], grad_fn=<AddmmBackward>), tensor([[-0.6248,  0.9180]], grad_fn=<AddmmBackward>), tensor([[0.2532, 0.1387]], grad_fn=<AddmmBackward>)]
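To connect this back to the question's setup, a sketch (assuming the 5d-reshaped encoder output from the question) could feed each per-cluster slice to its own linear layer and stack the results:

```python
import torch
import torch.nn as nn

bn, k, c = 2, 5, 3
out = torch.rand(bn, k, c, 32, 32)  # encoder output reshaped to (bn, k, c, 32, 32)

# one linear layer per cluster, each seeing only its own c channels
layers = nn.ModuleList(nn.Linear(c * 32 * 32, 2) for _ in range(k))

# flatten slice i and pass it through layer i, then stack along the k axis
means = torch.stack(
    [layers[i](out[:, i].reshape(bn, -1)) for i in range(k)], dim=1
)                              # (bn, k, 2)
means = means.unsqueeze(2)     # (bn, k, 1, 2), the shape asked for
```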

Upvotes: 1
