SlinkyPlatypus

Reputation: 11

How to calculate loss over a sliding window of samples and then backpropagate the weighted average loss

I am trying to implement a learning technique from a paper. The relevant portion is: The SNN baseline used a sliding window of 50 consecutive data points, representing 200 ms of data (50-point window, single-point stride) in order to calculate the loss, to allow for more information for backpropagation and avoid dead neurons and vanishing gradients. The MSE loss was linearly weighted from 0 to 1 for the 50 points within the window. However, I am getting issues with the gradient calculation.
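To make sure I understand the weighting itself, here is a rough standalone sketch of what I take the paper to mean (point_losses is just a stand-in tensor for the 50 per-point MSE values, not part of my actual code):

import torch

# Rough sketch of my reading of the paper: the 50 per-point MSE values in the
# current window are weighted linearly from 0 to 1 (oldest to newest) and then
# combined into a single loss for backpropagation. `point_losses` is a stand-in.
point_losses = torch.rand(50)
weights = torch.linspace(0.0, 1.0, steps=50)
weighted_loss = (weights * point_losses).sum()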

My attempt at implementing this is below:

from collections import deque

import torch

# Assumes net, criterion, optimizer, train_set_loader, val_set_loader,
# evaluate_model and patience are defined earlier.
best_loss = float('inf')
patience_counter = 0
torch.save(net.state_dict(), "cur_state.pth")  # initial checkpoint of the model state

for epoch in range(100):
    net.train()

    loss_window = deque(maxlen=50)
    window_sum = 0.0
    
    for i, data in enumerate(train_set_loader):
        inputs, label = data
        optimizer.zero_grad()
        outputs, _ = net(inputs)
        
        loss = criterion(outputs, label)
        
        loss_window.append(loss)
        
        window_sum = window_sum + loss.item()

        # Once the window is full, calculate the linearly weighted loss and backpropagate
        if len(loss_window) == 50:
            weighted_loss = sum(((j + 1) / 50) * loss_value for j, loss_value in enumerate(loss_window))
       
            print(f'weighted_loss: {weighted_loss}, requires_grad: {weighted_loss.requires_grad}')
            
            weighted_loss.backward(retain_graph=True)
            #weighted_loss.backward()
            optimizer.step()
            
            # Print gradient norms for debugging
            for name, param in net.named_parameters():
                if param.grad is not None:
                    print(f'{name}: {param.grad.norm()}')
            
            
            # Update the window sum by subtracting the oldest loss
            window_sum -= loss_window[0].item()

            # Remove the oldest loss element
            loss_window.popleft()
  
        


    net.eval()
    val_loss = evaluate_model(net, criterion, val_set_loader)
    
    
    if val_loss < best_loss:
        best_loss = val_loss
        # Checkpoint the new best weights (torch.save returns None, so don't assign it)
        torch.save(net.state_dict(), "cur_state.pth")
        patience_counter = 0
    else:
        patience_counter += 1
    
    if patience_counter >= patience:
        print("Early Stopping")
        print(f"Best loss: {best_loss}")
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        net.load_state_dict(torch.load("cur_state.pth", map_location=device))
        print("Model state returned to best performing.")
        break

If I try without retain_graph=True in my backward pass, I get the error that I'm trying to backward through the graph a second time. That makes sense, since 49 of the 50 losses in the window were already backpropagated in the previous call, so their graphs have already been freed. When I do use retain_graph=True, I get the error: "one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [50, 2]]". I believe this is an issue with how loss_window is updated, but I am not sure how to go about changing it. My question is somewhat similar to how to calculate loss over a number of images and then back propagate the average loss and update network weight, but in that question the windows don't overlap, which I believe avoids the problem.

Upvotes: 1

Views: 50

Answers (1)

SlinkyPlatypus

Reputation: 11

The problem can be solved by using the context manager that automatically applies copy-on-write to tensors saved for backward, which avoids this type of error: Automatic differentiation package - torch.autograd — PyTorch 2.3 documentation
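For reference, a minimal sketch of how this could be applied to the loop in the question, assuming the context manager meant is torch.autograd.graph.allow_mutation_on_saved_tensors (the copy-on-write mechanism described on that documentation page); both the forward passes and the backward calls run inside the same context:

import torch
from collections import deque

# Sketch only: net, criterion, optimizer and train_set_loader as in the question.
with torch.autograd.graph.allow_mutation_on_saved_tensors():
    loss_window = deque(maxlen=50)
    for inputs, label in train_set_loader:
        optimizer.zero_grad()
        outputs, _ = net(inputs)
        loss_window.append(criterion(outputs, label))

        if len(loss_window) == 50:
            # Linearly weighted sum over the 50 retained per-sample losses
            weighted_loss = sum(((j + 1) / 50) * l for j, l in enumerate(loss_window))
            # retain_graph=True because 49 of these losses are reused in the next step
            weighted_loss.backward(retain_graph=True)
            # In-place parameter updates no longer invalidate the retained graphs:
            # tensors saved for backward are cloned on mutation inside this context.
            optimizer.step()
            loss_window.popleft()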

This question was first answered here

Upvotes: 0
