StanGeo

Reputation: 429

Correct Validation Loss in Pytorch?

I am a bit confused about how to calculate validation loss. Should the validation loss be computed only at the end of an epoch, or should the loss also be monitored during iteration through the batches? Below I compute it using running_loss, which is accumulated over batches, but I want to check whether this is the correct approach.

def validate(loader, model, criterion):
    correct = 0
    total = 0
    running_loss = 0.0
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()
    mean_val_accuracy = 100 * correct / total
    mean_val_loss = running_loss
    print('Validation Accuracy: %d %%' % mean_val_accuracy)
    print('Validation Loss:', mean_val_loss)

Below is the training loop I am using:

def train(loader, model, criterion, optimizer, epoch):
    running_loss = 0.0
    model.train()
    for i, data in enumerate(loader):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:  # print the average loss every 2000 batches
            print('[%d , %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

    print('finished training')

Upvotes: 4

Views: 6660

Answers (1)

Louis Lac

Reputation: 6406

You can evaluate your network on the validation set whenever you want. It can be every epoch, or, if that is too costly because the dataset is huge, every N epochs.
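In code, validating every N epochs might look like the following sketch. The `train` and `validate` stubs stand in for the question's functions, and `num_epochs` and `N` are hypothetical choices:

```python
def train(epoch):
    # placeholder for the question's train(loader, model, criterion, optimizer, epoch)
    pass

validated_at = []

def validate(epoch):
    # placeholder for the question's validate(loader, model, criterion)
    validated_at.append(epoch)

num_epochs = 10
N = 3  # hypothetical choice: validate every 3rd epoch

for epoch in range(num_epochs):
    train(epoch)
    if (epoch + 1) % N == 0:  # run validation only every N epochs
        validate(epoch)

print(validated_at)  # → [2, 5, 8]
```

With N = 1 this reduces to validating after every epoch.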

What you did seems correct: you compute the loss over the whole validation set. You can optionally divide by its length in order to normalize the loss, so the scale stays the same if you later increase the size of the validation set.
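For instance, the accumulation in the question's `validate` function could be normalized by the number of batches so it reports a mean rather than a sum. A minimal sketch (the toy linear model, loader, and random data at the bottom are illustrative assumptions, not part of the question's setup):

```python
import torch
import torch.nn as nn

def validate(loader, model, criterion):
    """Return the mean per-batch loss and accuracy over the validation set."""
    correct, total, running_loss = 0, 0, 0.0
    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()
    # Dividing by the number of batches keeps the loss on the same
    # scale even if the validation set grows.
    mean_val_loss = running_loss / len(loader)
    mean_val_accuracy = 100.0 * correct / total
    return mean_val_loss, mean_val_accuracy

# Toy check with a random linear model and fake data (illustrative only)
torch.manual_seed(0)
model = nn.Linear(4, 3)
criterion = nn.CrossEntropyLoss()
loader = [(torch.randn(8, 4), torch.randint(0, 3, (8,))) for _ in range(5)]
loss, acc = validate(loader, model, criterion)
print(loss, acc)
```

Note this divides by the number of batches (mean per-batch loss); dividing by `total` instead would give the mean per-sample loss, which differs when batch sizes are unequal.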

Upvotes: 6
