racycle

Reputation: 13

PyTorch model training loss does not improve. Are the logistic regression model parameters/weights not updating?

I am trying to use the following logistic regression model to classify images, but the training loss does not converge or improve. Could you check the code and tell me whether the problem is in how the model implements logistic regression?

The results I get for a series of 10 training epochs are:

epoch: 1, loss= -16.0369
epoch: 2, loss= -23.3950
epoch: 3, loss= -23.4226
epoch: 4, loss= -18.7254
epoch: 5, loss= -29.8720
epoch: 6, loss= -29.2601
epoch: 7, loss= -21.3710
epoch: 8, loss= -28.2535
epoch: 9, loss= -33.8465
epoch: 10, loss= -27.8332

Model code:

import torch
import torch.nn as nn

class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = []
        self.linear.append(nn.Linear(in_features=28*28, out_features=1))
        self.linear = nn.Sequential(*self.linear)
        self.activation = nn.ReLU()

    def forward(self, x):
        # linear -> sigmoid -> ReLU
        y = self.activation(torch.sigmoid(self.linear(x)))
        return y

Loss and optimizer:

LR_model = LogisticRegression()
learn_rate = 0.01
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(params=LR_model.parameters(), lr=learn_rate)

The data loader yields "images" and "labels". Training loop segment:

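for epoch in range(num_epochs):            # num_epochs = 10 for the run above
    for images, labels in train_loader:    # train_loader: the data loader (name assumed)
        # forward
        y_predicted = LR_model(images)
        total_loss = criterion(y_predicted, labels.unsqueeze(1))
        # backward
        total_loss.backward()
        # update
        optimizer.step()
        optimizer.zero_grad()
    # Print epoch result (loss of the last batch in the epoch)
    print(f'epoch: {epoch+1}, loss= {total_loss.item():.4f}')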

Upvotes: 0

Views: 639

Answers (1)

Ivan

Reputation: 40648

You shouldn't apply a ReLU activation on top of your sigmoid activation: the sigmoid output is already strictly positive, so the ReLU is redundant at best. Instead, return torch.sigmoid(self.linear(x)) directly and pass it to nn.BCELoss.
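A minimal sketch of the model with the ReLU removed, so that forward returns sigmoid probabilities for nn.BCELoss. The flatten step, the in_features default and the random batch are assumptions for illustration (the question does not show the image tensor shape); the check on z confirms that ReLU after sigmoid leaves the values unchanged.

```python
import torch
import torch.nn as nn


class LogisticRegression(nn.Module):
    """Logistic regression: one linear layer followed by a sigmoid."""

    def __init__(self, in_features: int = 28 * 28):
        super().__init__()
        self.linear = nn.Linear(in_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: images may arrive as (N, 1, 28, 28); flatten to (N, 784).
        # For input that is already (N, 784) this is a no-op.
        x = x.flatten(start_dim=1)
        # Sigmoid alone maps the logit into (0, 1), which is what nn.BCELoss expects.
        return torch.sigmoid(self.linear(x))


# ReLU after sigmoid is redundant: sigmoid outputs are strictly positive,
# and ReLU is the identity on positive values.
z = torch.linspace(-5.0, 5.0, steps=11)
assert torch.equal(torch.relu(torch.sigmoid(z)), torch.sigmoid(z))

torch.manual_seed(0)
model = LogisticRegression()
criterion = nn.BCELoss()
images = torch.rand(4, 1, 28, 28)            # hypothetical batch of 4 grayscale images
labels = torch.tensor([0.0, 1.0, 1.0, 0.0])  # binary float targets
probs = model(images)                        # shape (4, 1), values in (0, 1)
loss = criterion(probs, labels.unsqueeze(1))
print(f"loss = {loss.item():.4f}")
```

With targets restricted to 0 and 1 and probabilities in (0, 1), the binary cross-entropy computed here is always non-negative.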

However, for numerical stability you should use nn.BCEWithLogitsLoss instead: it combines a sigmoid layer and the binary cross-entropy loss in a single class. In that case, forward should output only the logits, i.e. self.linear(x).
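A sketch of the logits-based variant, assuming the same flattened 28x28 input; the class name LogisticRegressionLogits and the random batch are hypothetical. It checks that nn.BCEWithLogitsLoss agrees with sigmoid + nn.BCELoss for moderate logits, and shows the stability difference for an extreme logit: sigmoid(200.0) rounds to exactly 1.0 in float32, so nn.BCELoss hits its documented log clamp (outputs >= -100) and reports 100, while nn.BCEWithLogitsLoss still returns about 200 with a usable gradient.

```python
import torch
import torch.nn as nn


class LogisticRegressionLogits(nn.Module):
    """Hypothetical variant of the model whose forward returns raw logits."""

    def __init__(self, in_features: int = 28 * 28):
        super().__init__()
        self.linear = nn.Linear(in_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No sigmoid here: nn.BCEWithLogitsLoss applies it internally.
        return self.linear(x.flatten(start_dim=1))


criterion = nn.BCEWithLogitsLoss()

# 1) For moderate logits both formulations agree.
logits = torch.tensor([[-2.0], [0.5], [3.0]])
targets = torch.tensor([[0.0], [1.0], [1.0]])
with_logits = criterion(logits, targets)
separate = nn.BCELoss()(torch.sigmoid(logits), targets)

# 2) Confidently wrong prediction: sigmoid(200.0) == 1.0 in float32, so
#    BCELoss clamps log(1 - p) at -100 and reports 100, while
#    BCEWithLogitsLoss returns the true value of about 200.
big = torch.tensor([[200.0]], requires_grad=True)
zero = torch.tensor([[0.0]])
clamped = nn.BCELoss()(torch.sigmoid(big.detach()), zero)
stable = criterion(big, zero)
stable.backward()  # d(loss)/d(logit) = sigmoid(logit) - target = 1.0 here

# 3) One SGD step on a hypothetical batch; integer labels must be cast to float.
torch.manual_seed(0)
model = LogisticRegressionLogits()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
images = torch.rand(4, 1, 28, 28)
labels = torch.tensor([0, 1, 1, 0])
loss = criterion(model(images), labels.float().unsqueeze(1))
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(with_logits.item(), separate.item(), clamped.item(), stable.item(), big.grad.item())
```

At inference time, apply torch.sigmoid to the model's logits to recover probabilities.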

Upvotes: 0
