Reputation: 1036
I have an encoder which outputs a tensor of shape (bn, c * k, 32, 32). I now want to produce k means of shape (bn, k, 1, 2), so the means are 2-dimensional coordinates. To do so, I want to use k FC layers, where each mean k_i only uses c of the channels.
So my idea is to reshape the encoder output out into a 5D tensor of shape (bn, k, c, 32, 32). Then I can use the flattened out[:, 0] ... out[:, k-1] as inputs to the k linear layers.
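For example (a minimal sketch; bn, c, k and out are the names from above):
out = out.view(bn, k, c, 32, 32)  # split the c*k channels into k groups of c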
The trivial solution would be to define the linear layers manually:
self.fc0 = nn.Linear(c * 32 * 32, 2)
...
self.fck = nn.Linear(c * 32 * 32, 2)
Then I could define the forward pass for each mean as follows:
mean_0 = self.fc0(out[:, 0].reshape(bn, -1))
...
mean_k = self.fck(out[:, k-1].reshape(bn, -1))
Is there a more efficient way to do that?
Upvotes: 2
Views: 111
Reputation: 40648
I believe you are looking for a grouped convolution. You can let axis=1 hold the k*c feature maps, so the input shape is (bn, k*c, 32, 32). Then use an nn.Conv2d layer with 2*k filters and groups=k, so it is not fully connected channel-wise: each filter only convolves one of the k groups of c maps, i.e. c channels at a time.
>>> import torch
>>> import torch.nn as nn
>>> bn = 1; k = 5; c = 3
>>> x = torch.rand(bn, k*c, 32, 32)
>>> m = nn.Conv2d(in_channels=c*k, out_channels=2*k, kernel_size=32, groups=k)
>>> m(x).shape
torch.Size([1, 10, 1, 1])
Which you can then reshape to your liking.
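For instance, to get the (bn, k, 1, 2) shape from the question:
>>> m(x).reshape(bn, k, 1, 2).shape
torch.Size([1, 5, 1, 2])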
In terms of the number of parameters, a typical (ungrouped) nn.Conv2d usage would be:
>>> m = nn.Conv2d(in_channels=c*k, out_channels=2*k, kernel_size=32)
>>> sum(layer.numel() for layer in m.parameters())
153610
Which is exactly c*k*2*k*32*32 weights, plus 2*k biases.
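Indeed, evaluating that formula in the same session gives the same count:
>>> c*k*2*k*32*32 + 2*k
153610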
In your case, you would have
>>> m = nn.Conv2d(in_channels=c*k, out_channels=2*k, kernel_size=32, groups=k)
>>> sum(layer.numel() for layer in m.parameters())
30730
Which is exactly c*2*k*32*32 weights, plus 2*k biases, i.e. k times fewer than the previous layer. A given filter has only c layers (not k*c), which means it takes an input with c channels (i.e. one of the k groups of c maps).
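Again, the formula checks out:
>>> c*2*k*32*32 + 2*k
30730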
Upvotes: 2
Reputation: 1515
You can do something like this using nn.ModuleList:
import torch
import torch.nn as nn

class fclist(nn.Module):
    def __init__(self, k):
        super().__init__()
        '''
        k: no. of clusters
        '''
        self.k = k
        '''
        .
        .
        Other previous layers
        .
        .
        '''
        c = 1
        # one small linear head per cluster, each mapping c*32*32 features to 2 coordinates
        self.out_layers = nn.ModuleList()
        for i in range(k):
            self.out_layers.append(nn.Linear(c*32*32, 2))

    def forward(self, x):
        '''
        .
        .
        pass through previous layers
        .
        .
        '''
        x = [layer(x) for layer in self.out_layers]
        return x
Sample output:
>>> net = fclist(k=3)
>>> inp = torch.randn(1, 1*32*32)
>>> net(inp)
[tensor([[-0.7319, -0.2686]], grad_fn=<AddmmBackward>), tensor([[-0.6248, 0.9180]], grad_fn=<AddmmBackward>), tensor([[0.2532, 0.1387]], grad_fn=<AddmmBackward>)]
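If each layer should only see its own c channels, as asked in the question, the forward pass could slice the reshaped 5D tensor instead of reusing one flattened input. A sketch, assuming out has already been reshaped to (bn, k, c, 32, 32):
# feed group i to layer i, then stack the k means into (bn, k, 1, 2)
means = [layer(out[:, i].reshape(out.shape[0], -1))
         for i, layer in enumerate(self.out_layers)]
means = torch.stack(means, dim=1).unsqueeze(2)  # shape (bn, k, 1, 2)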
Upvotes: 1