Reputation: 1025
I want to implement a custom differentiable function in PyTorch that acts like torch.clamp in the forward pass but in the backward pass returns gradients as if it were torch.tanh.
I tried the following code:
import torch

class ClampWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return torch.clamp(input, -1, 1)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input <= -1] = (1.0 - torch.tanh(input[input <= -1])**2.0) * grad_output[input <= -1]
        grad_input[input >= 1] = (1.0 - torch.tanh(input[input >= 1])**2.0) * grad_output[input >= 1]
        return grad_input
However, when I include this in my neural network, I get NaNs. How can I best implement this?
Upvotes: 0
Views: 656
Reputation: 1048
Calculate the tanh once and store it in a variable to avoid computing it multiple times. Also, clip the gradient you return to a maximum norm of 1.0:
def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    grad_input = grad_output.clone()
    tanh = torch.tanh(input)  # compute tanh once and reuse it below
    grad_input[input <= -1] = (1.0 - tanh[input <= -1]**2.0) * grad_output[input <= -1]
    grad_input[input >= 1] = (1.0 - tanh[input >= 1]**2.0) * grad_output[input >= 1]
    # Clip the returned gradient to a maximum norm of 1.0.
    # (torch.nn.utils.clip_grad_norm_ operates on the .grad attributes of
    # parameters, so the scaling is done directly on grad_input here.)
    max_norm = 1.0
    norm = grad_input.norm()
    if norm > max_norm:
        grad_input = grad_input * (max_norm / norm)
    return grad_input
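As a quick sanity check, you can also exercise the backward pass in isolation and look for NaNs directly (a minimal sketch; the tensor names below are only illustrative and assume the ClampWithGrad class defined in the question):

import torch

# Create a leaf tensor with some values outside [-1, 1].
x = (torch.randn(5) * 3).requires_grad_()

y = ClampWithGrad.apply(x)   # custom autograd Functions are invoked via .apply
y.sum().backward()

print(x.grad)                     # 1.0 inside [-1, 1], tanh'(x) outside
print(torch.isnan(x.grad).any())  # quick check for NaNs in the backward output

If this standalone check is clean, the NaNs are more likely coming from the rest of the network (e.g. exploding activations feeding into this function) than from the custom backward itself.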
Upvotes: 1