Wasi Ahmad

Reputation: 37691

L1/L2 regularization in PyTorch

How do I add L1/L2 regularization in PyTorch without manually computing it?

Upvotes: 131

Views: 268265

Answers (9)

Mateen Ulhaq

Reputation: 27201

Proof that weight_decay for torch.optim.Adam is the L2 regularization coefficient

The L2 regularized loss is:

L = f(θ) + ½λ∑θ²

Then, the derivative (gradient) vector is:

𝜕L/𝜕θ = 𝜕f/𝜕θ + λθ

PyTorch's Adam implementation computes the gradient as:

g = 𝜕L/𝜕θ = 𝜕f/𝜕θ + λθ

...where λ = weight_decay.
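
As a quick sanity check of this equivalence, here is a small sketch (toy values; λ, the learning rate, and the fake gradient are arbitrary assumptions) comparing one Adam step with weight_decay against one step where λθ is added to the gradient manually; both should produce the same update:

import torch

lam = 0.1  # λ = weight_decay (arbitrary value for this check)

theta_a = torch.nn.Parameter(torch.randn(5))
theta_b = torch.nn.Parameter(theta_a.detach().clone())

opt_a = torch.optim.Adam([theta_a], lr=1e-2, weight_decay=lam)
opt_b = torch.optim.Adam([theta_b], lr=1e-2, weight_decay=0.0)

grad = torch.randn(5)  # stand-in for ∂f/∂θ

# A: let Adam's weight_decay add λθ internally
theta_a.grad = grad.clone()
opt_a.step()

# B: add λθ to the gradient by hand
theta_b.grad = grad.clone() + lam * theta_b.detach()
opt_b.step()

print(torch.allclose(theta_a, theta_b))  # expected: True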

[Image: Adam algorithm as used by PyTorch (update formulas)]

Terminology

The later-published AdamW paper uses the term "decoupled weight decay" (green in the image below) to refer to a concept that is different from PyTorch's weight_decay (pink in the image below).

[Image: AdamW algorithm, from the AdamW paper]

Note that the original Adam paper does not explicitly mention the L2 regularization that PyTorch includes. Presumably, this is because L2 is easy enough to implement outside the optimizer:

parameters = [g["params"] for g in optimizer.param_groups]
l2 = sum(p.square().sum() for p in parameters)
loss = mse(...) + weight_decay * l2

Upvotes: 0

Szymon Maszke

Reputation: 24681

Previous answers, while technically correct, are inefficient performance-wise and not very modular (hard to apply on a per-layer basis, as provided by, say, Keras layers).

PyTorch L2 implementation

Why did PyTorch implement L2 inside torch.optim.Optimizer instances?

Let's take a look at the torch.optim.SGD source code (currently implemented as a functional optimization procedure), especially this part:

for i, param in enumerate(params):
    d_p = d_p_list[i]
    # L2 weight decay specified HERE!
    if weight_decay != 0:
        d_p = d_p.add(param, alpha=weight_decay)

  • One can see that d_p (the parameter's gradient) is modified and re-assigned for faster computation (no temporary variables are saved)
  • It has O(N) complexity without any complicated math like pow
  • It does not involve autograd extending the graph needlessly

Compare that to the loss-based approach: O(n) squaring (** 2) operations plus an addition, all of which also take part in backpropagation.

Math

Let's see the L2 equation with regularization factor alpha (the same could be done for L1, of course):

L₂ = ½α∑ᵢ wᵢ²

If we take the derivative of any loss with L2 regularization w.r.t. the parameters w (the regularization term is independent of the loss), we get:

𝜕L₂/𝜕wᵢ = αwᵢ

So it is simply an addition of alpha * weight to the gradient of every weight, and this is exactly what PyTorch does above!
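
A tiny autograd check of that derivative (toy values; alpha and the stand-in task loss are arbitrary assumptions): adding ½α∑w² to the loss shifts the gradient by exactly αw:

import torch

alpha = 0.1
w = torch.randn(4, requires_grad=True)
c = torch.randn(4)

# Gradient of a stand-in task loss alone
(g_task,) = torch.autograd.grad((w * c).sum(), w)

# Gradient of task loss + ½·α·∑w²
(g_total,) = torch.autograd.grad((w * c).sum() + 0.5 * alpha * w.pow(2).sum(), w)

print(torch.allclose(g_total, g_task + alpha * w.detach()))  # expected: True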

L1 Regularization layer

Using this (and some PyTorch magic), we can come up with a quite generic L1 regularization layer, but let's first look at the derivative of L1 (sgn is the signum function, returning 1 for positive input, -1 for negative, and 0 for 0):

𝜕L₁/𝜕wᵢ = α·sgn(wᵢ),  where L₁ = α∑ᵢ|wᵢ|

Full code with a WeightDecay interface is located in the torchlayers third-party library, which provides things like regularizing only weights/biases/specifically named parameters (disclaimer: I'm the author), but the essence of the idea is outlined below (see comments):

class L1(torch.nn.Module):
    def __init__(self, module, weight_decay):
        super().__init__()
        self.module = module
        self.weight_decay = weight_decay

        # Backward hook is registered on the specified module
        self.hook = self.module.register_full_backward_hook(self._weight_decay_hook)

    # The hook does not depend on the incoming backprop values (placeholder arguments)
    def _weight_decay_hook(self, *_):
        for param in self.module.parameters():
            # If there is no gradient or it was zeroed out
            # (usually zeroed out via optimizer.zero_grad()),
            # enable this check, e.g. with gradient accumulation or as a safer variant:
            # if param.grad is None or torch.all(param.grad == 0.0):

            # Apply regularization on it
            param.grad = self.regularize(param)

    def regularize(self, parameter):
        # L1 regularization formula
        return self.weight_decay * torch.sign(parameter.data)

    def forward(self, *args, **kwargs):
        # Simply forward args and kwargs to the wrapped module
        return self.module(*args, **kwargs)

Read more about hooks in this answer or the respective PyTorch docs if needed.

And usage is also pretty simple (it should work with gradient accumulation and PyTorch layers):

layer = L1(torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3), weight_decay=1e-4)  # weight_decay value is illustrative

Upvotes: 53

iacob

Reputation: 24171

L2 regularization out-of-the-box

Yes, PyTorch optimizers have a parameter called weight_decay which corresponds to the L2 regularization factor:

sgd = torch.optim.SGD(model.parameters(), weight_decay=weight_decay)

L1 regularization implementation

There is no analogous argument for L1; however, it is straightforward to implement manually:

loss = loss_fn(outputs, labels)
l1_lambda = 0.001
l1_norm = sum(p.abs().sum() for p in model.parameters())  # elementwise L1 over all parameters

loss = loss + l1_lambda * l1_norm

The equivalent manual implementation of L2 would be:

l2_reg = sum(p.pow(2).sum() for p in model.parameters())

Source: Deep Learning with PyTorch (8.5.2)

Upvotes: 28

Albert

Reputation: 68110

To extend on the good answers: as was said, the L2 norm added to the loss is equivalent to weight decay iff you use plain SGD without momentum. Otherwise, e.g. with Adam, it is not exactly the same. The AdamW paper [1] points out that decoupled weight decay is actually more stable. That is why you should use weight decay, which is an option of the optimizer, and consider using AdamW instead of Adam.

Also note that you probably don't want weight decay on all parameters (model.parameters()), but only on a subset; see the sketch below for an example.

[1] Decoupled Weight Decay Regularization (AdamW), 2017
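
A minimal sketch of that (the model and hyperparameter values are placeholder assumptions), using optimizer parameter groups so that biases and normalization parameters get no weight decay:

import torch

model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.LayerNorm(10))  # stand-in model

decay, no_decay = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # Common heuristic: no decay for biases and 1-D params (LayerNorm/BatchNorm weights)
    if name.endswith("bias") or param.ndim == 1:
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 1e-2},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=3e-4,
)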

Upvotes: 2

oukohou

Reputation: 401

For L1 regularization, including weights only:

l1_reg = torch.tensor(0., requires_grad=True)

for name, param in model.named_parameters():
    if 'weight' in name:
        l1_reg = l1_reg + param.abs().sum()  # elementwise L1 norm of the weight tensor

total_loss = total_loss + 10e-4 * l1_reg

Upvotes: 18

Sherif Ali

Reputation: 417

For L2 regularization,

l2_lambda = 0.01
l2_reg = torch.tensor(0.)

for param in model.parameters():
    l2_reg += param.pow(2).sum()  # sum of squared entries (the L2 penalty), not the unsquared norm

loss += l2_lambda * l2_reg


Upvotes: 30

Kashyap

Reputation: 6689

See the documentation. Add a weight_decay parameter to the optimizer for L2 regularization.
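
For instance (a hedged illustration; the model and the hyperparameter values are placeholders):

import torch

model = torch.nn.Linear(10, 2)  # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)  # weight_decay is the L2 factor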

Upvotes: 89

devil in the detail

Reputation: 3275

Use weight_decay > 0 for L2 regularization:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

Upvotes: 107

prosti

Reputation: 46291

Interestingly, torch.norm is slower on CPU but faster on GPU compared to the direct approach.

import torch
x = torch.randn(1024,100)
y = torch.randn(1024,100)

%timeit torch.sqrt((x - y).pow(2).sum(1))
%timeit torch.norm(x - y, 2, 1)

Out:

1000 loops, best of 3: 910 µs per loop
1000 loops, best of 3: 1.76 ms per loop

On the other hand:

import torch
x = torch.randn(1024,100).cuda()
y = torch.randn(1024,100).cuda()

%timeit torch.sqrt((x - y).pow(2).sum(1))
%timeit torch.norm(x - y, 2, 1)

Out:

10000 loops, best of 3: 50 µs per loop
10000 loops, best of 3: 26 µs per loop

Upvotes: 6
