Reputation: 11
I am trying to implement a learning technique from a paper. The relevant portion is:

> The SNN baseline used a sliding window of 50 consecutive data points, representing 200 ms of data (50-point window, single-point stride) in order to calculate the loss, to allow for more information for backpropagation and avoid dead neurons and vanishing gradients. The MSE loss was linearly weighted from 0 to 1 for the 50 points within the window.

However, I am getting issues with the gradient calculation.
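If I read that correctly, the weighting looks roughly like this (a minimal sketch, assuming the 50 per-point MSE values of the current window are available as a 1-D tensor; `per_point_mse` is a made-up name):

```python
import torch

# Hypothetical stand-in for the 50 per-point MSE values of one window
per_point_mse = torch.rand(50)

# Linear weights from 0 to 1 across the 50 points, as described in the paper
weights = torch.linspace(0, 1, steps=50)

# Weighted window loss
weighted_loss = (weights * per_point_mse).sum()
```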
My attempt at implementing this is below:

```python
for epoch in range(100):
    net.train()
    best_loss = float('inf')
    patience_counter = 0
    best_state = torch.save(net.state_dict(), "cur_state.pth")
    loss_window = deque(maxlen=50)
    window_sum = 0.0
    for i, data in enumerate(train_set_loader):
        inputs, label = data
        optimizer.zero_grad()
        outputs, _ = net(inputs)
        loss = criterion(outputs, label)
        loss_window.append(loss)
        window_sum = window_sum + loss.item()
        # If the window is full, calculate the weighted loss and perform backpropagation
        if len(loss_window) == 50:
            weighted_loss = sum(((j + 1) / 50) * loss_value for j, loss_value in enumerate(loss_window))
            print(f'weighted_loss: {weighted_loss}, requires_grad: {weighted_loss.requires_grad}')
            weighted_loss.backward(retain_graph=True)
            # weighted_loss.backward()
            optimizer.step()
            # Print gradients for debugging
            for name, param in net.named_parameters():
                if param.grad is not None:
                    print(f'{name}: {param.grad.norm()}')
            # Update the window sum by subtracting the oldest loss
            window_sum -= loss_window[0].item()
            # Remove oldest loss element
            loss_window.popleft()
    net.eval()
    val_loss = evaluate_model(net, criterion, val_set_loader)
    if val_loss < best_loss:
        best_loss = val_loss
        best_state = torch.save(net.state_dict(), "cur_state.pth")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early Stopping")
            print(f"Best loss: {best_loss}")
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            net.load_state_dict(torch.load("cur_state.pth", map_location=device))
            print("Model state returned to best performing.")
            break
```
If I call `backward()` without `retain_graph=True`, I get the error that I'm trying to backward through the graph a second time. This makes sense, since 49 of the 50 losses in the window were already backpropagated (and their graphs freed) by the previous call. When I do use `retain_graph=True`, I get the error: "one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [50, 2]]". I believe this is caused by the updating of `loss_window`, but I am not sure how to go about changing this. My question is somewhat similar to how to calculate loss over a number of images and then back propagate the average loss and update network weight, but in that question there is no overlap between windows, which I believe is what avoids the problem.
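Here is a stripped-down sketch (a toy parameter instead of my real network) that I think reproduces the first error and shows the overlap problem, i.e. the second window tries to backpropagate through graphs already freed by the first backward call:

```python
import torch

w = torch.nn.Parameter(torch.randn(3))

# Three "per-step" losses that all depend on w; w * w ensures tensors are
# saved for backward, as they would be in a real network
losses = [(w * w * i).sum() for i in range(1, 4)]

sum(losses[0:2]).backward()  # first window: frees the graphs of losses[0] and losses[1]
sum(losses[1:3]).backward()  # overlapping window: RuntimeError, losses[1]'s graph is already freed
```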
Upvotes: 1
Views: 50
Reputation: 11
The problem can be solved by using the context manager described in Automatic differentiation package - torch.autograd — PyTorch 2.3 documentation, which automatically applies copy-on-write to the saved tensors and avoids this type of error.
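If I understand the docs correctly, the context manager in question is `torch.autograd.graph.allow_mutation_on_saved_tensors()`. A sketch of wrapping the original training loop in it (reusing `net`, `optimizer`, `criterion`, `train_set_loader` and `loss_window` from the question) would look something like this:

```python
import torch

# torch.autograd.graph.allow_mutation_on_saved_tensors() clones tensors saved
# for backward when they are later mutated in place (copy-on-write), so the
# retained graphs of older losses stay valid across optimizer.step() calls.
with torch.autograd.graph.allow_mutation_on_saved_tensors():
    for i, data in enumerate(train_set_loader):
        inputs, label = data
        optimizer.zero_grad()
        outputs, _ = net(inputs)
        loss = criterion(outputs, label)
        loss_window.append(loss)
        if len(loss_window) == 50:
            weighted_loss = sum(((j + 1) / 50) * l for j, l in enumerate(loss_window))
            weighted_loss.backward(retain_graph=True)
            optimizer.step()
            loss_window.popleft()
```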
This question was first answered here
Upvotes: 0