Jacob Stern

Reputation: 4597

How to prevent gradient computations for certain elements of a tensor in PyTorch

To be clear, I am not asking how to detach a whole tensor or disable requires_grad for an entire parameter.

I'm wondering how to forgo gradient computations for some elements of a loss tensor that give a NaN gradient every time -- essentially, how to call .detach() on individual elements of a tensor. The TensorFlow way to do this is tf.stop_gradient, see this question.

Some context: My neural network computes a distance matrix of its predicted coordinates, as follows. The entries of the distance matrix D are given by d_ij = || coordinates_i - coordinates_j ||. I want to backpropagate through the distance matrix creation step. However, the norm function includes a square root, which is not differentiable at 0 -- and the diagonal of the distance matrix is 0 by construction. Thus I get NaN gradients for the diagonal of the distance matrix. I would like to mask out the gradients on the diagonal of the distance matrix.
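
The failure can be reproduced in isolation. The following is a minimal sketch of the mechanism, separate from the network: the derivative of sqrt is infinite at 0, and the chain rule multiplies that infinity by the zero derivative of the squared distance, producing NaN.

import torch

x = torch.zeros(1, requires_grad=True)
y = torch.sqrt(x * x)  # forward value is 0, which is fine
y.backward()
print(x.grad)          # tensor([nan]): d(sqrt(u))/du = inf at u = 0, times du/dx = 0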

Minimal working example:

import torch

def compute_distance_matrix(coordinates):
    L = len(coordinates)
    gram_matrix = torch.mm(coordinates, torch.transpose(coordinates, 0, 1))
    gram_diag = torch.diagonal(gram_matrix, dim1=0, dim2=1)
    # gram_diag: L
    diag_1 = torch.matmul(gram_diag.unsqueeze(-1), torch.ones(1, L).to(coordinates.device))
    # diag_1: L x L
    diag_2 = torch.transpose(diag_1, dim0=0, dim1=1)
    # diag_2: L x L
    distance_matrix = torch.sqrt(diag_1 + diag_2 - (2 * gram_matrix))
    return distance_matrix

# In reality, pred_coordinates is an output of the network, but we initialize it here for a minimal working example
L = 10
pred_coordinates = torch.randn(L, 3, requires_grad=True)
true_coordinates = torch.randn(L, 3, requires_grad=False)
obj = torch.nn.MSELoss()
optimizer = torch.optim.Adam([pred_coordinates])

for i in range(500):
    optimizer.zero_grad()  # clear gradients accumulated from the previous iteration
    pred_distance_matrix = compute_distance_matrix(pred_coordinates)
    true_distance_matrix = compute_distance_matrix(true_coordinates)
    loss = obj(pred_distance_matrix, true_distance_matrix)
    loss.backward()
    print(loss.item())
    optimizer.step()

gives

1.2868314981460571
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
...
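
Note that the first loss value is finite; the NaNs appear in the backward pass, so the very first optimizer.step() already corrupts pred_coordinates. A quick sanity check (a sketch reusing compute_distance_matrix, obj, L, and true_coordinates from above):

pred_coordinates = torch.randn(L, 3, requires_grad=True)
loss = obj(compute_distance_matrix(pred_coordinates), compute_distance_matrix(true_coordinates))
loss.backward()
print(loss.item())                               # finite, e.g. ~1.3
print(torch.isnan(pred_coordinates.grad).any())  # tensor(True)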

Upvotes: 3

Views: 3011

Answers (1)

Jacob Stern

Reputation: 4597

I initialized a new tensor and used a mask to copy over only the entries whose gradients are well-defined (here, the off-diagonal entries), then applied the not-everywhere-differentiable operation (the square root) to just those entries. This way the gradient only flows back through the entries selected by the mask.
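
In isolation, the pattern looks like this (a minimal sketch, not the full distance-matrix code): write the transformed values into a fresh tensor only at the masked positions, so the backward pass of the square root never sees the problematic entries.

import torch

x = torch.tensor([0.0, 4.0, 9.0], requires_grad=True)
mask = x > 0                     # entries where sqrt has a well-defined gradient
out = torch.zeros_like(x)
out[mask] = torch.sqrt(x[mask])  # sqrt (and its backward) only ever see the masked entries
out.sum().backward()
print(x.grad)                    # tensor([0.0000, 0.2500, 0.1667]) -- no NaN at x = 0

Applied to the distance matrix, with a mask that excludes the diagonal: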

import torch

def compute_distance_matrix(coordinates):
    L = len(coordinates)
    gram_matrix = torch.mm(coordinates, torch.transpose(coordinates, 0, 1))
    gram_diag = torch.diagonal(gram_matrix, dim1=0, dim2=1)
    # gram_diag: L
    diag_1 = torch.matmul(gram_diag.unsqueeze(-1), torch.ones(1, L).to(coordinates.device))
    # diag_1: L x L
    diag_2 = torch.transpose(diag_1, dim0=0, dim1=1)
    # diag_2: L x L
    squared_distance_matrix = diag_1 + diag_2 - (2 * gram_matrix)
    distance_matrix = torch.zeros_like(squared_distance_matrix)
    # boolean mask that is True off the diagonal and False on it
    mask = ~torch.eye(L, dtype=torch.bool).to(coordinates.device)
    # apply sqrt only to the off-diagonal entries, so no gradient flows through sqrt(0)
    distance_matrix[mask] = torch.sqrt(squared_distance_matrix.masked_select(mask))
    return distance_matrix

# In reality, pred_coordinates is an output of the network, but we initialize it here for a minimal working example
L = 10
pred_coordinates = torch.randn(L, 3, requires_grad=True)
true_coordinates = torch.randn(L, 3, requires_grad=False)
obj = torch.nn.MSELoss()
optimizer = torch.optim.Adam([pred_coordinates])

for i in range(500):
    optimizer.zero_grad()  # clear gradients accumulated from the previous iteration
    pred_distance_matrix = compute_distance_matrix(pred_coordinates)
    true_distance_matrix = compute_distance_matrix(true_coordinates)
    loss = obj(pred_distance_matrix, true_distance_matrix)
    loss.backward()
    print(loss.item())
    optimizer.step()

which gives:

1.222102403640747
1.2191187143325806
1.2162436246871948
1.2133947610855103
1.210543155670166
1.2076761722564697
1.204787015914917
1.2018715143203735
1.198927402496338
1.1959534883499146
1.1929489374160767
1.1899129152297974
1.1868458986282349
1.1837480068206787
1.180619239807129
1.1774601936340332
1.174271583557129
...
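
As a final sanity check (again a sketch, reusing the definitions above), the gradient on pred_coordinates is now finite everywhere:

pred_coordinates = torch.randn(L, 3, requires_grad=True)
loss = obj(compute_distance_matrix(pred_coordinates), compute_distance_matrix(true_coordinates))
loss.backward()
print(torch.isnan(pred_coordinates.grad).any())  # tensor(False)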

Upvotes: 3
