razvanc92

Reputation: 173

Different filters for 2d convolution

I have an input of shape (B(atch), F(eatures), N(odes), T(imestamps)). Right now, if I apply a 2d convolution with a kernel of shape (1,2), I have a total of (F_out, F_in, 1, 2) weights to learn, which is alright. I want to extend this so that each node in the input has its own filter of shape (1,2). Does anyone have an idea where I should start? So far I loop over all N nodes and apply each filter to its respective input. Unfortunately, this approach is very slow.

Upvotes: 1

Views: 417

Answers (1)

Shai

Reputation: 114866

You are looking for "grouped convolution".
The doc for nn.Conv2d regarding the groups parameter:

At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated.
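The quoted behavior can be checked with a small sketch. The sizes below (4 input channels, 6 output channels, kernel size 2) are illustrative assumptions, not values from the question; the two "side by side" layers simply reuse the grouped layer's weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumed): 4 input channels, 6 output channels, 2 groups.
conv = nn.Conv1d(4, 6, kernel_size=2, groups=2)
x = torch.rand(5, 4, 10)

# Two independent convs, each seeing half the input channels
# and producing half the output channels.
left = nn.Conv1d(2, 3, kernel_size=2)
right = nn.Conv1d(2, 3, kernel_size=2)
with torch.no_grad():
    # Grouped weight has shape (out_channels, in_channels / groups, kernel).
    left.weight.copy_(conv.weight[:3])
    left.bias.copy_(conv.bias[:3])
    right.weight.copy_(conv.weight[3:])
    right.bias.copy_(conv.bias[3:])

y_grouped = conv(x)
y_side_by_side = torch.cat([left(x[:, :2]), right(x[:, 2:])], dim=1)
groups_match = torch.allclose(y_grouped, y_side_by_side, atol=1e-6)
```

If the two results agree, `groups_match` is True, which is the "two conv layers side by side" picture from the docs.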

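In your case, you want groups = number of nodes, so that each node gets its own set of filters.

It's not quite that simple, though, because your (1,2) kernel only spans the time dimension: you need to "merge" features and nodes into a single channel dimension of size inf*n and apply a 1d grouped convolution over time. The layer therefore needs inf*n input channels and outf*n output channels.
Moreover, you need to permute "node"s and "feature"s before merging, so that the features of each node are contiguous and fall into the same group. Note that with kernel_size=2 and no padding, the output has t-1 timestamps.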

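import torch
import torch.nn as nn

b = 10     # batch size
inf = 8    # input features per node
outf = 13  # output features per node
n = 3      # number of nodes
t = 50     # number of timestamps

x = torch.rand((b, inf, n, t))  # input tensor, b-inf-n-t
# grouped conv: one group per node, each mapping inf -> outf channels
gconv = nn.Conv1d(inf * n, outf * n, kernel_size=2, groups=n)

# put nodes ahead of features so each node's features form one contiguous group;
# reshape (not view) because the permuted tensor is non-contiguous
x_ready = x.permute(0, 2, 1, 3).reshape(b, n * inf, t)
y_grouped = gconv(x_ready)  # b-(n*outf)-(t-1)
# "fix" y: split the channels back into nodes and features
y = y_grouped.view(b, n, outf, t - 1).permute(0, 2, 1, 3)  # now y is b-outf-n-(t-1)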
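As a sanity check, the grouped result can be compared with the per-node loop described in the question. This is only a sketch: it assumes the grouped layer is built with inf*n input channels and outf*n output channels, and the per-node `nn.Conv2d` layers are hypothetical stand-ins for the asker's filters, filled with the matching slices of the grouped weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, inf, outf, n, t = 10, 8, 13, 3, 50

x = torch.rand(b, inf, n, t)
gconv = nn.Conv1d(inf * n, outf * n, kernel_size=2, groups=n)

# Grouped path: merge (node, feature) into channels, convolve, split back.
x_ready = x.permute(0, 2, 1, 3).reshape(b, n * inf, t)
y = gconv(x_ready).view(b, n, outf, t - 1).permute(0, 2, 1, 3)

# Reference path: one (1, 2) Conv2d per node, applied in a loop.
outs = []
for i in range(n):
    conv_i = nn.Conv2d(inf, outf, kernel_size=(1, 2))
    with torch.no_grad():
        w = gconv.weight[i * outf:(i + 1) * outf]   # (outf, inf, 2)
        conv_i.weight.copy_(w.unsqueeze(2))          # (outf, inf, 1, 2)
        conv_i.bias.copy_(gconv.bias[i * outf:(i + 1) * outf])
    outs.append(conv_i(x[:, :, i:i + 1, :]))         # (b, outf, 1, t-1)
y_loop = torch.cat(outs, dim=2)                      # (b, outf, n, t-1)

outputs_match = torch.allclose(y, y_loop, atol=1e-5)
```

If `outputs_match` is True, the single grouped call computes the same thing as the loop over nodes, just in one kernel launch instead of n.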

Upvotes: 2
