I'm wondering how to forgo gradient computations for some elements of a loss tensor that give a NaN gradient every time -- essentially, how to call .detach() for individual elements of a tensor. The way to do this in TensorFlow is with tf.stop_gradient, see this question. To be clear, I am not asking how to disable gradient computation for a whole tensor (that would be setting requires_grad = False for that tensor, or calling tensor.detach(), see this question).
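For illustration only, here is a rough sketch of the element-wise behaviour I am after in PyTorch, using torch.where with a hypothetical boolean mask (gradients should flow only through the entries where the mask is True):
import torch

x = torch.randn(4, requires_grad=True)
mask = torch.tensor([True, False, True, False])  # hypothetical mask: keep gradients only where True
# Entries where mask is True read from x (gradient kept);
# entries where mask is False read from the detached copy (gradient blocked).
y = torch.where(mask, x, x.detach())
y.sum().backward()
print(x.grad)  # tensor([1., 0., 1., 0.])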
Some context: My neural network computes a distance matrix of its predicted coordinates, as follows. The entries of the distance matrix D are given by d_ij = ||coordinates_i - coordinates_j||. I want to backpropagate through the distance matrix creation step. However, the norm function includes a square root, which is not differentiable at 0 -- and the diagonal of the distance matrix is 0 by construction. Thus I get NaN gradients for the diagonal of the distance matrix, and I would like to mask those gradients out.
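To see where the NaN comes from in isolation, here is a minimal sketch of a single diagonal entry: the derivative of sqrt at 0 is infinite, and the sqrt backward divides the zero upstream gradient by zero.
import torch

# One diagonal entry of the squared distance matrix, which is exactly 0
s = torch.zeros(1, requires_grad=True)
d = torch.sqrt(s)        # d/ds sqrt(s) = 1 / (2 * sqrt(s)), infinite at s = 0
loss = (d - 0.0) ** 2    # the matching target entry is also 0
loss.backward()
print(s.grad)            # tensor([nan]): 0 / 0 inside the sqrt backward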
Minimal working example:
import torch

def compute_distance_matrix(coordinates):
    L = len(coordinates)
    gram_matrix = torch.mm(coordinates, torch.transpose(coordinates, 0, 1))
    gram_diag = torch.diagonal(gram_matrix, dim1=0, dim2=1)
    # gram_diag: L
    diag_1 = torch.matmul(gram_diag.unsqueeze(-1), torch.ones(1, L).to(coordinates.device))
    # diag_1: L x L
    diag_2 = torch.transpose(diag_1, dim0=0, dim1=1)
    # diag_2: L x L
    distance_matrix = torch.sqrt(diag_1 + diag_2 - (2 * gram_matrix))
    return distance_matrix

# In reality, pred_coordinates is an output of the network, but we initialize it here for a minimal working example
L = 10
pred_coordinates = torch.randn(L, 3, requires_grad=True)
true_coordinates = torch.randn(L, 3, requires_grad=False)

obj = torch.nn.MSELoss()
optimizer = torch.optim.Adam([pred_coordinates])

for i in range(500):
    pred_distance_matrix = compute_distance_matrix(pred_coordinates)
    true_distance_matrix = compute_distance_matrix(true_coordinates)
    loss = obj(pred_distance_matrix, true_distance_matrix)
    loss.backward()
    print(loss.item())
    optimizer.step()
which gives:
1.2868314981460571
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
...
Upvotes: 3
Views: 3011
I initialized a new tensor of zeros and used a mask to select only the entries with well-defined gradients from the squared distance matrix (here, the off-diagonal entries), applied the not-everywhere-differentiable operation (the square root) to those selected values, and copied the results into the new tensor. This way, the gradient flows back only through the entries selected by the mask; the diagonal stays at zero and is never passed through the square root.
import torch

def compute_distance_matrix(coordinates):
    L = len(coordinates)
    gram_matrix = torch.mm(coordinates, torch.transpose(coordinates, 0, 1))
    gram_diag = torch.diagonal(gram_matrix, dim1=0, dim2=1)
    # gram_diag: L
    diag_1 = torch.matmul(gram_diag.unsqueeze(-1), torch.ones(1, L).to(coordinates.device))
    # diag_1: L x L
    diag_2 = torch.transpose(diag_1, dim0=0, dim1=1)
    # diag_2: L x L
    squared_distance_matrix = diag_1 + diag_2 - (2 * gram_matrix)
    # Only take the square root of the off-diagonal entries; the diagonal stays 0,
    # so no gradient ever flows through sqrt at 0.
    distance_matrix = torch.zeros_like(squared_distance_matrix)
    mask = ~torch.eye(L, dtype=torch.bool).to(coordinates.device)
    distance_matrix[mask] = torch.sqrt(squared_distance_matrix.masked_select(mask))
    return distance_matrix

# In reality, pred_coordinates is an output of the network, but we initialize it here for a minimal working example
L = 10
pred_coordinates = torch.randn(L, 3, requires_grad=True)
true_coordinates = torch.randn(L, 3, requires_grad=False)

obj = torch.nn.MSELoss()
optimizer = torch.optim.Adam([pred_coordinates])

for i in range(500):
    pred_distance_matrix = compute_distance_matrix(pred_coordinates)
    true_distance_matrix = compute_distance_matrix(true_coordinates)
    loss = obj(pred_distance_matrix, true_distance_matrix)
    loss.backward()
    print(loss.item())
    optimizer.step()
which gives:
1.222102403640747
1.2191187143325806
1.2162436246871948
1.2133947610855103
1.210543155670166
1.2076761722564697
1.204787015914917
1.2018715143203735
1.198927402496338
1.1959534883499146
1.1929489374160767
1.1899129152297974
1.1868458986282349
1.1837480068206787
1.180619239807129
1.1774601936340332
1.174271583557129
...
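As a quick sanity check (a minimal sketch reusing the masked compute_distance_matrix above, with a fresh tensor introduced just for the check), the gradient with respect to the coordinates now stays finite:
# Fresh coordinates, one backward pass through the masked distance matrix
check_coordinates = torch.randn(L, 3, requires_grad=True)
compute_distance_matrix(check_coordinates).sum().backward()
print(torch.isfinite(check_coordinates.grad).all())  # expected: tensor(True)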
Upvotes: 3