Mark

Reputation: 29

Pytorch: Custom thresholding activation function - gradient

I created an activation function class Threshold that should operate on one-hot-encoded image tensors.

The function performs min-max feature scaling on each channel followed by thresholding.

class Threshold(nn.Module):
    def __init__(self, threshold=.5):
        super().__init__()
        if threshold < 0.0 or threshold > 1.0:
            raise ValueError("Threshold value must be in [0,1]")
        else:
            self.threshold = threshold

    def min_max_fscale(self, input):
        r"""
        Applies min-max feature scaling to input. Each channel is treated individually.
        Input is assumed to be N x C x H x W (one-hot-encoded prediction).
        """
        for i in range(input.shape[0]):
            # N
            for j in range(input.shape[1]):
                # C
                min = torch.min(input[i][j])
                max = torch.max(input[i][j])
                input[i][j] = (input[i][j] - min) / (max - min)
        return input

    def forward(self, input):
        assert len(input.shape) == 4, f"input has wrong number of dims. Must have dim = 4 but has dim {len(input.shape)}"

        input = self.min_max_fscale(input)
        return (input >= self.threshold) * 1.0

When I use the function I get the following error, I assume because the gradients are not being tracked automatically.

Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I already had a look at How to properly update the weights in PyTorch? but could not work out how to apply it to my case.

How is it possible to calculate the gradients for this function?

Thanks for your help.

Upvotes: 0

Views: 797

Answers (1)

Ivan
Ivan

Reputation: 40648

The issue is that you are manipulating and overwriting elements in place; this type of operation can't be tracked by autograd. Instead, you should stick with built-in functions. Your example is not that tricky to tackle: you are looking to retrieve the minimum and maximum values for each of the input.shape[0] x input.shape[1] channels, then scale the whole tensor in one go, i.e. in vectorized form. No for loops involved!

One way to compute min/max along multiple axes is to flatten those:

>>> x_f = x.flatten(2)

Then, find the min and max along the flattened axis while keeping the reduced dimension, so the results broadcast back against the input:

>>> x_min = x_f.min(axis=-1, keepdim=True).values
>>> x_max = x_f.max(axis=-1, keepdim=True).values
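To make the shapes concrete, here is a small sketch (the tensor sizes are just an illustration) showing that flattening and reducing with keepdim=True leaves a per-(batch, channel) singleton dimension that broadcasts against the flattened input:

```python
import torch

x = torch.rand(2, 3, 4, 4)   # N x C x H x W
x_f = x.flatten(2)           # N x C x (H*W) -> shape (2, 3, 16)

x_min = x_f.min(dim=-1, keepdim=True).values  # shape (2, 3, 1)
x_max = x_f.max(dim=-1, keepdim=True).values  # shape (2, 3, 1)

print(x_f.shape, x_min.shape, x_max.shape)
```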

The resulting min_max_fscale function would look something like:

class Threshold(nn.Module):
    def min_max_fscale(self, x):
        r"""
        Applies min max feature scaling to input. Each channel is treated individually. 
        Input is assumed to be N x C x H x W (one-hot-encoded prediction)
        """
        x_f = x.flatten(2)
        x_min, x_max = x_f.min(-1, True).values, x_f.max(-1, True).values

        x_f = (x_f - x_min) / (x_max - x_min)
        return x_f.reshape_as(x)
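As a quick sanity check (the input sizes here are arbitrary), the vectorized version stays on the autograd graph: calling backward through it populates the input's gradient, and each channel ends up scaled to the [0, 1] range:

```python
import torch
import torch.nn as nn

class Threshold(nn.Module):
    def min_max_fscale(self, x):
        x_f = x.flatten(2)
        x_min, x_max = x_f.min(-1, True).values, x_f.max(-1, True).values
        x_f = (x_f - x_min) / (x_max - x_min)
        return x_f.reshape_as(x)

x = torch.rand(1, 2, 3, 3, requires_grad=True)
y = Threshold().min_max_fscale(x)
y.sum().backward()           # gradients flow through the scaling
print(x.grad is not None)    # True
```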

Important note:

You would notice that you can now backpropagate on min_max_fscale... but not on forward. This is because you are applying a boolean condition which is not a differentiable operation.
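You can see this directly: the comparison produces a tensor that is detached from the graph. One common workaround (not part of the original answer, so treat it as a suggestion) is a straight-through estimator, which uses the hard values in the forward pass but the scaled values' gradient in the backward pass:

```python
import torch

x = torch.rand(1, 2, 3, 3, requires_grad=True)
y = (x >= 0.5) * 1.0     # comparison detaches from the autograd graph
print(y.requires_grad)   # False - y.sum().backward() would raise

# straight-through estimator: forward uses the hard values,
# backward uses the identity gradient of x
hard = (x >= 0.5).float()
y_ste = x + (hard - x).detach()
print(y_ste.requires_grad)  # True
```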

Upvotes: 1
