Rakib

Reputation: 7625

Constrain parameters to be -1, 0 or 1 in neural network in pytorch

I want to constrain the parameters of an intermediate layer in a neural network to prefer the discrete values -1, 0, or 1. The idea is to add a custom objective function that increases the loss if the parameters take any other value. Note that I want to constrain the parameters of a particular layer, not all layers.

How can I implement this in pytorch? I want to add this custom loss to the total loss in the training loop, something like this:

custom_loss = constrain_parameters_to_be_discrete 
loss = other_loss + custom_loss 

Maybe using a Dirichlet prior might help; any pointers on this?

Upvotes: 4

Views: 1815

Answers (2)

Shai

Reputation: 114876

You can use the loss function:

import torch

def custom_loss_function(x):
    # Zero at x in {-1, 0, 1}, positive everywhere else
    loss = torch.abs(x**2 - torch.abs(x))
    return loss.mean()

This graph plots the proposed loss for a single element:
[figure: plot of |x^2 - |x|| for a single element]

As the plot shows, the proposed loss is zero for x in {-1, 0, 1} and positive otherwise.

Note that if you want to apply this loss to the weights of a specific layer, then the x here is that layer's weight tensor, not its activations.
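For completeness, here is a minimal sketch of wiring this into a training loop, applying the penalty only to one layer's weights (the toy model, dummy data, and the lambda_reg weighting are my own assumptions, not part of the question):

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(20, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
lambda_reg = 0.1  # assumed weighting of the constraint term

inputs = torch.randn(32, 20)           # dummy batch
targets = torch.randint(0, 2, (32,))   # dummy labels

optimizer.zero_grad()
other_loss = criterion(model(inputs), targets)
# Penalize only the weights of the second Linear layer (index 2)
custom_loss = custom_loss_function(model[2].weight)
loss = other_loss + lambda_reg * custom_loss
loss.backward()
optimizer.step()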

Upvotes: 1

Szymon Maszke

Reputation: 24814

Extending @Shai's answer and mixing it with this answer, one could do it more simply via a custom layer into which you pass your specific layer.

First, the derivative of torch.abs(x**2 - torch.abs(x)), taken from WolframAlpha (check here), is placed inside the regularize function.
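For reference, a short derivation of that derivative (valid away from the non-differentiable points x = -1, 0, 1), which matches the expression used in regularize below:

\[
\frac{d}{dx}\,\bigl|x^2 - |x|\bigr|
  = \operatorname{sign}\bigl(x^2 - |x|\bigr)\,\bigl(2x - \operatorname{sign}(x)\bigr)
  = \bigl(\operatorname{sign}(x) - 2x\bigr)\,\operatorname{sign}\bigl(1 - |x|\bigr)
\]

(the last step uses sign(x^2 - |x|) = sign(|x| - 1) = -sign(1 - |x|)).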

Now the Constrainer layer:

import torch

class Constrainer(torch.nn.Module):
    def __init__(self, module, weight_decay=1.0):
        super().__init__()
        self.module = module
        self.weight_decay = weight_decay

        # Backward hook is registered on the specified module
        self.hook = self.module.register_full_backward_hook(self._weight_decay_hook)

    # Does not work with gradient accumulation; check the original answer
    # and the pointers there if that's needed
    def _weight_decay_hook(self, *_):
        for parameter in self.module.parameters():
            parameter.grad = self.regularize(parameter)

    def regularize(self, parameter):
        # Derivative of the regularization term proposed by @Shai
        sgn = torch.sign(parameter)
        return self.weight_decay * (
            (sgn - 2 * parameter) * torch.sign(1 - parameter * sgn)
        )

    def forward(self, *args, **kwargs):
        # Simply forward args and kwargs to the wrapped module
        return self.module(*args, **kwargs)

Usage is really simple (tune the weight_decay hyperparameter if you need more or less force on the parameters):

constrained_layer = Constrainer(torch.nn.Linear(20, 10), weight_decay=0.1)

Now you don't have to worry about different loss functions and can use your model normally.
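For example, a minimal end-to-end sketch of dropping the constrained layer into a model (the surrounding architecture, dummy data, and optimizer are my own assumptions for illustration):

model = torch.nn.Sequential(
    torch.nn.Linear(20, 20),
    torch.nn.ReLU(),
    Constrainer(torch.nn.Linear(20, 10), weight_decay=0.1),  # constrained layer
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2),
)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = torch.nn.CrossEntropyLoss()

inputs = torch.randn(32, 20)
targets = torch.randint(0, 2, (32,))

optimizer.zero_grad()
loss = criterion(model(inputs), targets)  # no extra loss term needed
loss.backward()   # the backward hook fills in the regularization gradient
optimizer.step()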

Upvotes: 1
