Guillermina

Reputation: 3657

Turn CNN model into class

I am trying to build a CNN for multilabel classification in PyTorch (each image can have more than one label). So far I have built the model as follows:

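import torch.nn as nn
import torch.optim as optim

# `model` is an existing network whose final layer (model.fc) receives 2048 features
model.fc = nn.Sequential(nn.Linear(2048, 512),
                         nn.ReLU(),
                         nn.Dropout(0.2),
                         nn.Linear(512, 10),
                         nn.LogSigmoid())
                         # nn.LogSoftmax(dim=1))

criterion = nn.NLLLoss()
# criterion = nn.BCELoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.003)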

But I would like to build it using a class like the following example:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(10, 20, 5)
        # 20 channels x 22 x 39 after two conv+pool stages on a 3 x 100 x 170 input
        self.fc1 = nn.Linear(20 * 22 * 39, 100)
        self.fc2 = nn.Linear(100, 50)
        self.fc3 = nn.Linear(50, 10)
        self.fc4 = nn.Linear(10, 3)

    def forward(self, x):
        x = x.view(-1, 3, 100, 170)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 20 * 22 * 39)  # flatten for the fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)
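The 20 * 22 * 39 passed to fc1 has to match the flattened output of the conv/pool layers. One way to check it is to push a dummy tensor through the same conv/pool stack. A minimal sketch, assuming 3 x 100 x 170 inputs as in the view call above (the standalone `features` module is only for illustration):

```python
import torch
import torch.nn as nn

# Same conv/pool stack as in Net (assumed input size: 3 x 100 x 170)
features = nn.Sequential(
    nn.Conv2d(3, 10, 5), nn.ReLU(), nn.MaxPool2d(2, 2),   # 100x170 -> 96x166 -> 48x83
    nn.Conv2d(10, 20, 5), nn.ReLU(), nn.MaxPool2d(2, 2),  # 48x83 -> 44x79 -> 22x39
)

with torch.no_grad():
    out = features(torch.zeros(1, 3, 100, 170))

print(out.shape)       # torch.Size([1, 20, 22, 39])
print(out[0].numel())  # 17160, i.e. 20 * 22 * 39
```

If the input size changes, rerunning this check gives the new value for fc1's in_features.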

What would be the best way to accomplish this, given that I am dealing with a multilabel classification problem? I would appreciate any insights.

Upvotes: 0

Views: 93

Answers (1)

Szymon Maszke

Reputation: 24691

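  1. You should use torch.nn.BCEWithLogitsLoss for multilabel classification, with the network returning raw logits, instead of a LogSigmoid output combined with NLLLoss. BCEWithLogitsLoss combines the sigmoid and binary cross-entropy in one numerically stable step, whereas NLLLoss expects exactly one target class per sample.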

  2. You have to output N values (one per class) for each sample in the batch. In the target vector, a 1 at position i means that class i is present in the image.

  3. Your network is fine, provided you only have 3 labels to predict (each either 0 or 1). You may want to reconsider its design or use a pretrained model; other than that, it should at least run.
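Putting the three points together, a minimal sketch might look like the following. It keeps the question's layers, returns raw logits, trains with BCEWithLogitsLoss on multi-hot float targets, and turns logits into predictions with a sigmoid. The class name `MultiLabelNet`, the `num_labels` parameter, the dummy batch, and the 0.5 decision threshold are all hypothetical choices for illustration, not part of the original answer:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiLabelNet(nn.Module):
    """Net from the question, returning one raw logit per label (hypothetical name)."""

    def __init__(self, num_labels: int = 3):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.fc1 = nn.Linear(20 * 22 * 39, 100)
        self.fc2 = nn.Linear(100, 50)
        self.fc3 = nn.Linear(50, 10)
        self.fc4 = nn.Linear(10, num_labels)

    def forward(self, x):
        # x: (batch, 3, 100, 170)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # (batch, 20 * 22 * 39)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)  # raw logits: BCEWithLogitsLoss applies the sigmoid itself


model = MultiLabelNet(num_labels=3)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

# Dummy batch (hypothetical): 4 images, multi-hot float targets, 1 = label present
images = torch.randn(4, 3, 100, 170)
targets = torch.tensor([[1., 0., 1.],
                        [0., 0., 0.],
                        [1., 1., 1.],
                        [0., 1., 0.]])

# One training step
model.train()
optimizer.zero_grad()
logits = model(images)             # (4, 3)
loss = criterion(logits, targets)
loss.backward()
optimizer.step()

# Inference: independent per-label probabilities, thresholded at an assumed 0.5
model.eval()
with torch.no_grad():
    probs = torch.sigmoid(model(images))
    preds = (probs > 0.5).int()    # (4, 3) matrix of 0/1 label decisions
```

Because each label gets its own sigmoid, several labels can be predicted for the same image, which is what a softmax-based setup cannot do. The threshold can be tuned per label if some classes are rare.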

Upvotes: 1
