Tob

Reputation: 1025

Custom PyTorch function with hard clamp forward and soft clamp backward

I want to implement a custom differentiable function in PyTorch that behaves like torch.clamp in the forward pass but, in the backward pass, returns the gradient as if the forward had been a tanh.

I tried the following code:

import torch

class ClampWithGrad(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Hard clamp in the forward pass
        return torch.clamp(input, -1, 1)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Outside [-1, 1], replace the clamp's zero gradient with the tanh derivative
        grad_input[input <= -1] = (1.0 - torch.tanh(input[input <= -1])**2.0) * grad_output[input <= -1]
        grad_input[input >= 1] = (1.0 - torch.tanh(input[input >= 1])**2.0) * grad_output[input >= 1]
        return grad_input

However, when I include this in my neural network, I get NaNs. How can I best implement this?

Upvotes: 0

Views: 656

Answers (1)

Phoenix

Reputation: 1048

Calculate the tanh once and store it in a variable so it isn't recomputed for each mask, and clip the gradient to a maximum norm of 1.0 to keep it from blowing up:

def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    grad_input = grad_output.clone()
    # Compute tanh only once and reuse it for both masks
    tanh = torch.tanh(input)
    grad_input[input <= -1] = (1.0 - tanh[input <= -1]**2.0) * grad_output[input <= -1]
    grad_input[input >= 1] = (1.0 - tanh[input >= 1]**2.0) * grad_output[input >= 1]
    # Clip the gradient to a maximum norm of 1.0.
    # torch.nn.utils.clip_grad_norm_ operates on parameters' .grad fields,
    # so rescale this gradient tensor directly instead.
    max_norm = 1.0
    norm = grad_input.norm()
    if norm > max_norm:
        grad_input = grad_input * (max_norm / norm)
    return grad_input
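
A minimal usage sketch, assuming `import torch` and the `ClampWithGrad` class from the question with the backward above dropped in:

x = torch.tensor([-3.0, -0.5, 0.5, 3.0], requires_grad=True)
y = ClampWithGrad.apply(x)   # forward: hard clamp to [-1, 1]
y.sum().backward()           # backward: tanh-style gradient outside [-1, 1]

print(y)       # clamped values: [-1.0, -0.5, 0.5, 1.0]
print(x.grad)  # ~tanh'(x) where |x| >= 1, 1.0 elsewhere, rescaled so the norm is at most 1.0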

Upvotes: 1
