Burger

Reputation: 413

pytorch multiple branches of a model

[model diagram: inputs C1 ~ Cp, each fed through a CNN branch of the same structure; the branch outputs are concatenated]

Hi, I'm trying to build this model using PyTorch.

Each input consists of 20 images of size 28 x 28, labelled C1 ~ Cp in the diagram. Each image goes through a CNN of the same structure, but their outputs are eventually concatenated.

I'm currently struggling with feeding the multiple inputs to their respective CNN models. Each model in the first box, with three convolutional layers, would look like the code below, but I'm not quite sure how to feed 20 different inputs into separate models of the same structure and eventually concatenate their outputs.

        self.features = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(10, 14, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(14, 18, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(28*28*18, 256)
        )
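
For clarity, here's a quick shape check of one branch as I understand it (the standalone features variable and the batch size of 4 are just for illustration): with padding=1 the spatial size stays at 28 x 28, so each branch maps a (batch, 1, 28, 28) image to a 256-dimensional feature vector.

    import torch
    import torch.nn as nn

    # one branch, identical to the snippet above
    features = nn.Sequential(
        nn.Conv2d(1, 10, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(10, 14, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(14, 18, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(28*28*18, 256),
    )

    x = torch.randn(4, 1, 28, 28)   # one of the 20 images, batch of 4
    print(features(x).shape)        # torch.Size([4, 256])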

I've tried out giving a list of inputs as an input to forward function, but it ended up with an error and won't go through. I'll be more than happy to explain further if anything is unclear.

Upvotes: 4

Views: 4178

Answers (2)

Alexey Birukov

Reputation: 1680

Assuming each path has its own weights, maybe this could be done with grouped convolution, although the pre-fusion Linear can cause some trouble.

    P = 20
    self.features = nn.Sequential(
        nn.Conv2d(1*P, 10*P, kernel_size=3, padding=1, groups=P),
        nn.ReLU(),
        nn.Conv2d(10*P, 14*P, kernel_size=3, padding=1, groups=P),
        nn.ReLU(),
        nn.Conv2d(14*P, 18*P, kernel_size=3, padding=1, groups=P),
        nn.ReLU(),
        nn.Conv2d(18*P, 256*P, kernel_size=28,          groups=P),  # not sure about this one
        nn.Flatten(),
        nn.Linear(256*P, 1024)
    )
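
A quick sketch of how the input would be arranged for this variant (the standalone features variable and the batch size of 8 are my assumptions; the stack itself is the one above): the 20 images of each sample are stacked along the channel dimension, so the whole batch is a single (batch, 20, 28, 28) tensor, and groups=P keeps the 20 paths independent until the final Linear fuses them.

    import torch
    import torch.nn as nn

    P = 20  # number of parallel paths

    features = nn.Sequential(
        nn.Conv2d(1*P, 10*P, kernel_size=3, padding=1, groups=P),
        nn.ReLU(),
        nn.Conv2d(10*P, 14*P, kernel_size=3, padding=1, groups=P),
        nn.ReLU(),
        nn.Conv2d(14*P, 18*P, kernel_size=3, padding=1, groups=P),
        nn.ReLU(),
        nn.Conv2d(18*P, 256*P, kernel_size=28, groups=P),  # 28x28 -> 1x1, 256 features per path
        nn.Flatten(),
        nn.Linear(256*P, 1024),  # fusion across the 20 paths
    )

    # the 20 images of each sample stacked along the channel dimension
    x = torch.randn(8, P, 28, 28)
    print(features(x).shape)  # torch.Size([8, 1024])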

Upvotes: 3

KonstantinosKokos

Reputation: 3453

Simply define forward as taking a list of tensors as input, then process each input with its corresponding CNN (in the example snippet the CNNs share the same structure but don't share parameters, which is what I assume you need). You'll need to fill in the dots ... according to your specifications.

    class MyModel(torch.nn.Module):
        def __init__(self, ...):
            ...
            self.cnns = torch.nn.ModuleList([torch.nn.Sequential(...) for _ in range(20)])

        def forward(self, xs: list[torch.Tensor]):
            return torch.cat([cnn(x) for x, cnn in zip(xs, self.cnns)], dim=...)
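
For concreteness, here is one way the dots could be filled in, reusing the three-conv branch from the question (the n_branches default, the dim=1 concatenation, and the shapes in the usage lines are assumptions to adapt):

    import torch
    import torch.nn as nn

    class MyModel(nn.Module):
        def __init__(self, n_branches: int = 20):
            super().__init__()
            # one CNN per input image: same structure, independent weights
            self.cnns = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(1, 10, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(10, 14, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(14, 18, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.Flatten(),
                    nn.Linear(28*28*18, 256),
                )
                for _ in range(n_branches)
            ])

        def forward(self, xs: list[torch.Tensor]) -> torch.Tensor:
            # each x: (batch, 1, 28, 28); concatenate the per-branch features
            return torch.cat([cnn(x) for cnn, x in zip(self.cnns, xs)], dim=1)

    model = MyModel()
    xs = [torch.randn(4, 1, 28, 28) for _ in range(20)]
    print(model(xs).shape)  # torch.Size([4, 5120]), i.e. 20 branches x 256 features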

Upvotes: 4
