Reputation: 413
Hi, I'm trying to build this model using PyTorch.
Each input consists of 20 images of size 28 x 28, shown as C1 ~ Cp in the image. Each image goes to a CNN of the same structure, and the outputs are eventually concatenated.
I'm currently struggling with feeding multiple inputs to their respective CNN models. Each model in the first box, with three convolutional layers, looks like the code below, but I'm not sure how to pass 20 different inputs to separate models of the same structure and then concatenate the results.
self.features = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(10, 14, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(14, 18, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(28*28*18, 256)
)
I've tried passing a list of inputs to the forward function, but it raised an error and wouldn't run. I'm happy to explain further if anything is unclear.
Upvotes: 4
Views: 4178
Reputation: 1680
Assuming each path has its own weights, this could perhaps be done with grouped convolutions, although the Linear layer that precedes the fusion can cause some trouble.
P = 20
self.features = nn.Sequential(
    nn.Conv2d(1*P, 10*P, kernel_size=3, padding=1, groups=P),
    nn.ReLU(),
    nn.Conv2d(10*P, 14*P, kernel_size=3, padding=1, groups=P),
    nn.ReLU(),
    nn.Conv2d(14*P, 18*P, kernel_size=3, padding=1, groups=P),
    nn.ReLU(),
    nn.Conv2d(18*P, 256*P, kernel_size=28, groups=P),  # not sure about this one: a 28x28 kernel standing in for the per-path Linear
    nn.Flatten(),
    nn.Linear(256*P, 1024)
)
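As a rough check of the grouped-convolution idea, here is a minimal sketch that builds the same stack and pushes a dummy batch through it. It assumes the 20 images are stacked along the channel axis into one tensor of shape (batch, 20, 28, 28). The helper name `make_grouped_features` and its parameters are made up for illustration, and the demo uses much smaller layer widths than 256 and 1024 just to keep it light.

```python
import torch
from torch import nn


def make_grouped_features(p: int = 20, branch_features: int = 256,
                          fused_features: int = 1024) -> nn.Sequential:
    """Hypothetical helper: p parallel paths expressed as grouped convolutions."""
    return nn.Sequential(
        # groups=p keeps every path separate: path i only sees input channel i.
        nn.Conv2d(1 * p, 10 * p, kernel_size=3, padding=1, groups=p),
        nn.ReLU(),
        nn.Conv2d(10 * p, 14 * p, kernel_size=3, padding=1, groups=p),
        nn.ReLU(),
        nn.Conv2d(14 * p, 18 * p, kernel_size=3, padding=1, groups=p),
        nn.ReLU(),
        # A kernel covering the whole 28x28 map reduces each path to a
        # 1x1 output, which plays the role of the per-path Linear layer.
        nn.Conv2d(18 * p, branch_features * p, kernel_size=28, groups=p),
        nn.Flatten(),
        # Fusion layer: this is the first place where the paths are mixed.
        nn.Linear(branch_features * p, fused_features),
    )


# Small widths (assumed values) so the demo stays cheap to run.
features = make_grouped_features(p=20, branch_features=8, fused_features=32)
x = torch.randn(4, 20, 28, 28)  # 4 samples, 20 images stacked as channels
y = features(x)
print(y.shape)  # torch.Size([4, 32])
```

With this layout the model takes a single tensor rather than a list, so the concatenation happens implicitly in the channel dimension before the final `nn.Linear`.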
Upvotes: 3
Reputation: 3453
Simply define forward as taking a list of tensors as input, then process each input with the corresponding CNN (in the example snippet, the CNNs share the same structure but don't share parameters, which is what I assume you need). You'll need to fill in the dots ... according to your specifications.
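import torch

class MyModel(torch.nn.Module):
    def __init__(self, ...):
        super().__init__()  # must run before submodules are assigned
        ...
        # one independent CNN per input; same structure, separate parameters
        self.cnns = torch.nn.ModuleList([torch.nn.Sequential(...) for _ in range(20)])

    def forward(self, xs: list[torch.Tensor]):
        return torch.cat([cnn(x) for x, cnn in zip(xs, self.cnns)], dim=...)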
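For reference, a filled-in version of the snippet above might look like the sketch below. It reuses the layer stack from the question for every branch; the names `make_branch` and `MultiBranchCNN`, the choice of `dim=1` for the concatenation, and the assumption that each list element has shape (batch, 1, 28, 28) are mine, not part of the original answer. The demo uses a small `out_features` instead of 256 only to keep memory use down.

```python
import torch
from torch import nn


def make_branch(out_features: int) -> nn.Sequential:
    """Hypothetical helper: one CNN with the layer stack from the question."""
    return nn.Sequential(
        nn.Conv2d(1, 10, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(10, 14, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(14, 18, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(28 * 28 * 18, out_features),
    )


class MultiBranchCNN(nn.Module):
    def __init__(self, num_branches: int = 20, out_features: int = 256):
        super().__init__()
        # ModuleList registers every branch, so each gets its own parameters.
        self.cnns = nn.ModuleList(
            [make_branch(out_features) for _ in range(num_branches)]
        )

    def forward(self, xs):
        # xs: sequence of tensors, one per branch, each (batch, 1, 28, 28).
        if len(xs) != len(self.cnns):
            raise ValueError(f"expected {len(self.cnns)} inputs, got {len(xs)}")
        # Concatenate the per-branch feature vectors along the feature axis.
        return torch.cat([cnn(x) for x, cnn in zip(xs, self.cnns)], dim=1)


model = MultiBranchCNN(num_branches=20, out_features=8)
batch = torch.randn(4, 20, 28, 28)  # 4 samples, 20 images each
xs = batch.split(1, dim=1)          # 20 tensors of shape (4, 1, 28, 28)
out = model(xs)
print(out.shape)  # torch.Size([4, 160])
```

If the data arrives as one stacked tensor, `split(1, dim=1)` is one way to turn it into the per-branch inputs that `forward` expects.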
Upvotes: 4