Jouline

Reputation: 90

Validation losses increasing after a few epochs

I'm building a small CNN model to predict crop plant diseases with the Plant Village dataset. It consists of 39 classes covering different species, with and without diseases.
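For reference, the data loaders are set up roughly like this (the paths and batch size are placeholders; the 224x224 resize is what makes the flattened feature map match the 32*28*28 input of the fully connected layer below):

import torch
from torchvision import datasets, transforms

# Resize to 224x224 so that, after the two conv/pool stages below,
# the feature map flattens to 32*28*28 as the fc layer expects.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# "data/train" and "data/valid" are placeholder paths; ImageFolder
# expects one subfolder per class.
train_dataset = datasets.ImageFolder("data/train", transform=transform)
validation_dataset = datasets.ImageFolder("data/valid", transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=32, shuffle=False)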

CNN model

import torch.nn as nn

class CropDetectCNN(nn.Module):
    # initialize the layers and parameters
    def __init__(self):
        super(CropDetectCNN, self).__init__()

        # convolutional layer 1 & max pool layer 1: 3x224x224 -> 16x111x111
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3),
            nn.MaxPool2d(kernel_size=2))

        # convolutional layer 2 & max pool layer 2: 16x111x111 -> 32x28x28
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=2),
            nn.MaxPool2d(kernel_size=2))

        # fully connected layer: 32*28*28 flattened features -> 39 class scores
        self.fc = nn.Linear(32*28*28, 39)

    # feed the input forward through the network
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)  # flatten to (batch_size, 32*28*28)
        out = self.fc(out)
        return out


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CropDetectCNN().to(device)

Training

criterion = nn.CrossEntropyLoss()  # applies softmax + cross-entropy internally, so the model outputs raw logits
optimizer = torch.optim.Adam(model.parameters())

import numpy as np
from datetime import datetime

def batch_gd(model, criterion, train_loader, validation_loader, epochs):
    train_losses = np.zeros(epochs)
    validation_losses = np.zeros(epochs)
    
    for e in range(epochs):
        t0 = datetime.now()
        train_loss = []
        model.train()
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()

            output = model(inputs)

            loss = criterion(output, targets)

            train_loss.append(loss.item())  # torch to numpy world

            loss.backward()
            optimizer.step()
        train_loss = np.mean(train_loss)  # average training loss over the epoch

        validation_loss = []

        for inputs, targets in validation_loader:
            
            model.eval()

            inputs, targets = inputs.to(device), targets.to(device)

            output = model(inputs)

            loss = criterion(output, targets)

            validation_loss.append(loss.item())  # torch to numpy world
        validation_loss = np.mean(validation_loss)  # average validation loss over the epoch

        train_losses[e] = train_loss
        validation_losses[e] = validation_loss

        dt = datetime.now() - t0

        print(
            f"Epoch : {e+1}/{epochs} Train_loss: {train_loss:.3f} Validation_loss: {validation_loss:.3f} Duration: {dt}"
        )

    return train_losses, validation_losses

# Running the function
train_losses, validation_losses = batch_gd(
    model, criterion, train_loader, validation_loader, 5
)

These are the results:
Epoch : 1/5 Train_loss: 1.164 Validation_loss: 0.861 Duration: 0:10:59.968168
Epoch : 2/5 Train_loss: 0.515 Validation_loss: 0.816 Duration: 0:10:49.199842
Epoch : 3/5 Train_loss: 0.241 Validation_loss: 1.007 Duration: 0:09:56.334155
Epoch : 4/5 Train_loss: 0.156 Validation_loss: 1.147 Duration: 0:10:12.625819
Epoch : 5/5 Train_loss: 0.135 Validation_loss: 1.603 Duration: 0:09:56.746308

Isn't the validation loss supposed to decrease with epochs? So why does it first decrease and then increase?

How should I set the number of epochs, and why?

Any help is really appreciated!

Upvotes: 1

Views: 1313

Answers (1)

Vladimir Kulyashov

Reputation: 406

When the validation loss starts rising while the training loss keeps falling, you are facing overfitting: the model has begun to memorize the training set instead of learning features that generalize. You should stop training around that point (early stopping) and use regularization tricks such as dropout, weight decay, or data augmentation to push the turning point later.
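A minimal early-stopping sketch on top of your loop (the patience value and checkpoint path are illustrative, not fixed rules):

patience = 3                          # how many non-improving epochs to tolerate (illustrative)
best_validation_loss = float("inf")
bad_epochs = 0

for e in range(epochs):
    # ... your existing training and validation code for one epoch ...

    if validation_loss < best_validation_loss:
        best_validation_loss = validation_loss
        bad_epochs = 0
        torch.save(model.state_dict(), "best_model.pt")  # checkpoint the best weights so far
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            print(f"Stopping early at epoch {e+1}")
            break

This also answers your second question: rather than fixing the number of epochs in advance, set it generously and let early stopping decide when to halt.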

Getting different predictions at inference time can also happen when the model is left in training mode and autograd keeps tracking operations. Call model.eval() once before validating and explicitly stop gradient tracking with torch.no_grad().
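Concretely, the validation part of your batch_gd would become:

model.eval()                       # put layers like dropout/batch norm into inference mode
validation_loss = []
with torch.no_grad():              # skip building the autograd graph during evaluation
    for inputs, targets in validation_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        output = model(inputs)
        loss = criterion(output, targets)
        validation_loss.append(loss.item())
model.train()                      # switch back to training mode for the next epoch
validation_loss = np.mean(validation_loss)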

Upvotes: 4
