Reputation: 7625
I want to constrain the parameters of an intermediate layer in a neural network to prefer discrete values: -1, 0, or 1. The idea is to add a custom objective function that increases the loss if the parameters take any other value. Note that I want to constrain the parameters of one particular layer, not all layers.
How can I implement this in PyTorch? I want to add this custom loss to the total loss in the training loop, something like this:
custom_loss = constrain_parameters_to_be_discrete
loss = other_loss + custom_loss
Maybe a Dirichlet prior would help; any pointers on this?
Upvotes: 4
Views: 1815
Reputation: 114876
You can use the loss function:
import torch

def custom_loss_function(x):
    loss = torch.abs(x**2 - torch.abs(x))
    return loss.mean()
(Figure: plot of the proposed loss for a single element.)
The proposed loss is zero for x ∈ {-1, 0, 1} and positive otherwise.
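A quick numeric check of that claim, using only the expression from the loss above (the sample values are arbitrary). It also checks the factorised form |x**2 - |x|| = |x| * ||x| - 1|, which makes it clear why exactly -1, 0 and 1 are the zeros:

```python
import torch

# Arbitrary sample points: the three targets plus two off-target values
x = torch.tensor([-1.0, 0.0, 1.0, 0.5, 2.0])
per_element = torch.abs(x**2 - torch.abs(x))  # penalty before .mean()
print(per_element.tolist())  # [0.0, 0.0, 0.0, 0.25, 2.0]

# Same values via the factorised form |x| * ||x| - 1|
factorised = torch.abs(x) * torch.abs(torch.abs(x) - 1)
print(torch.allclose(per_element, factorised))  # True
```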
Note that if you want to apply this loss to the weights of a specific layer, then x here is that layer's weights, not its activations.
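Putting it together with the training loop from the question: a minimal sketch, assuming a small hypothetical `nn.Sequential` model whose middle `Linear` layer is the one to constrain, dummy regression data, and a hypothetical penalty weight `lam` (the question adds the penalty unscaled, which corresponds to `lam = 1`). Only that layer's `.weight` is passed to `custom_loss_function`, so the other layers, and this layer's bias, stay unconstrained:

```python
import torch
import torch.nn as nn

def custom_loss_function(x):
    # Zero at x in {-1, 0, 1}, positive elsewhere
    loss = torch.abs(x**2 - torch.abs(x))
    return loss.mean()

# Hypothetical model; only model[2] (the middle Linear layer) is constrained
torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(8, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),
    nn.Linear(16, 1),
)
constrained = model[2]
lam = 0.1  # hypothetical penalty weight; tune for your task

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
inputs, targets = torch.randn(32, 8), torch.randn(32, 1)  # dummy data

for step in range(100):
    optimizer.zero_grad()
    other_loss = criterion(model(inputs), targets)
    # Penalize the weights of the chosen layer only, not its activations
    custom_loss = custom_loss_function(constrained.weight)
    loss = other_loss + lam * custom_loss
    loss.backward()
    optimizer.step()
```

Increasing `lam` pushes the weights harder towards {-1, 0, 1} at the expense of `other_loss`.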
Upvotes: 1
Reputation: 24814
Extending @Shai's answer and combining it with this answer, you can do it more simply with a custom wrapper layer into which you pass the specific layer you want to constrain.
First, the derivative of torch.abs(x**2 - torch.abs(x)), computed with WolframAlpha (check here), goes inside the regularize function.
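To check the closed-form derivative before relying on it, the sketch below compares it with what autograd computes for @Shai's expression, at points away from -1, 0 and 1 (where the expression is not differentiable). `penalty_grad` is a hypothetical name; its body is the same formula used in `regularize` below, with `weight_decay=1`:

```python
import torch

def penalty_grad(w):
    # Closed-form derivative of |w**2 - |w|| (same formula as in `regularize`)
    sgn = torch.sign(w)
    return (sgn - 2 * w) * torch.sign(1 - w * sgn)

# Arbitrary points away from the kinks at -1, 0, 1
w = torch.tensor([-2.0, -0.75, -0.25, 0.3, 0.6, 1.5], requires_grad=True)
torch.abs(w**2 - torch.abs(w)).sum().backward()
print(torch.allclose(w.grad, penalty_grad(w.detach())))  # True
```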
Now the Constrainer layer:
import torch


class Constrainer(torch.nn.Module):
    def __init__(self, module, weight_decay=1.0):
        super().__init__()
        self.module = module
        self.weight_decay = weight_decay
        # Backward hook is registered on the specified module
        self.hook = self.module.register_full_backward_hook(self._weight_decay_hook)

    # Add to .grad instead of assigning it, so the gradient of the task loss
    # is kept whether autograd accumulates it before or after this hook runs
    def _weight_decay_hook(self, *_):
        with torch.no_grad():
            for parameter in self.module.parameters():
                reg = self.regularize(parameter)
                if parameter.grad is None:
                    parameter.grad = reg
                else:
                    parameter.grad += reg

    def regularize(self, parameter):
        # Element-wise derivative of the regularization term proposed by @Shai
        sgn = torch.sign(parameter)
        return self.weight_decay * (
            (sgn - 2 * parameter) * torch.sign(1 - parameter * sgn)
        )

    def forward(self, *args, **kwargs):
        # Simply forward args and kwargs to the wrapped module
        return self.module(*args, **kwargs)
Usage is simple; adjust the weight_decay hyperparameter if you need more or less force on the parameters:
constrained_layer = Constrainer(torch.nn.Linear(20, 10), weight_decay=0.1)
Now you don't have to worry about different loss functions and can use your model normally.
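A possible variant, sketched here as an assumption rather than part of the answer above, attaches the same derivative to each parameter with `Tensor.register_hook`. PyTorch calls such a hook with the parameter's incoming gradient and uses the returned tensor in its place, so the regularization term is added to the task-loss gradient before it reaches `.grad` and nothing is overwritten. `add_discrete_grad_hooks` is a hypothetical helper name, and the weights in the usage part are arbitrary values chosen so the result is easy to check by hand:

```python
import torch

def add_discrete_grad_hooks(module, weight_decay=1.0):
    """Hypothetical helper: adds the derivative of |w**2 - |w|| to the
    gradient of every parameter of `module` via Tensor.register_hook."""
    handles = []
    for param in module.parameters():
        def hook(grad, param=param):  # default arg binds the current param
            w = param.detach()
            sgn = torch.sign(w)
            reg = weight_decay * (sgn - 2 * w) * torch.sign(1 - w * sgn)
            # The returned tensor replaces the incoming gradient
            return grad + reg
        handles.append(param.register_hook(hook))
    return handles  # call .remove() on each handle to detach the hooks

layer = torch.nn.Linear(3, 2)
add_discrete_grad_hooks(layer, weight_decay=1.0)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.25, 2.0, -2.0], [0.0, 1.0, -1.0]]))

# All-zero inputs give a zero task gradient for the weight,
# so layer.weight.grad holds only the regularization term
layer(torch.zeros(4, 3)).sum().backward()
print(layer.weight.grad)
```

The second row of weights already sits on {-1, 0, 1}, so its regularization gradient is zero, while 0.25, 2.0 and -2.0 receive 0.5, 3.0 and -3.0 respectively.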
Upvotes: 1