Toonia

Reputation: 145

What exactly is meant by param_groups in pytorch?

I would like to update the learning rate for each weight matrix and each bias in PyTorch during training. The answers here and here, and many other answers I found online, suggest doing this with the model's param_groups, which, to the best of my knowledge, apply learning rates to groups of parameters rather than to a specific layer's weight or bias. I also want to update the learning rates during training, not just set them in advance with torch.optim.

Any help is appreciated.

Upvotes: 7

Views: 10912

Answers (1)

jodag

Reputation: 22244

Updates to model parameters are handled by an optimizer in PyTorch. When you define the optimizer, you have the option of partitioning the model parameters into different groups, called param groups. Each param group can have different optimizer settings. For example, one group of parameters could have a learning rate of 0.1 and another a learning rate of 0.01.
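
As a small illustration of the idea (a sketch, not part of the original answer), the snippet below builds a hypothetical two-layer nn.Sequential model, puts each layer in its own param group with a different learning rate, and checks that one plain SGD step (no momentum, no weight decay) moves each weight by -lr * grad using its own group's lr.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
# Hypothetical two-layer model, used only for this illustration.
model = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 1))

# Two param groups with different learning rates.
optimizer = optim.SGD(
    [
        {"params": model[0].parameters(), "lr": 0.1},   # first layer
        {"params": model[1].parameters(), "lr": 0.01},  # second layer
    ],
    lr=0.01,  # default for any group that does not set its own 'lr'
)

# Snapshot the weights, then take one plain SGD step.
w0 = model[0].weight.detach().clone()
w1 = model[1].weight.detach().clone()
model(torch.ones(2, 4)).sum().backward()
optimizer.step()

# Each weight moved by -lr * grad, using the lr of its own group.
print(torch.allclose(model[0].weight.detach(), w0 - 0.1 * model[0].weight.grad))
print(torch.allclose(model[1].weight.detach(), w1 - 0.01 * model[1].weight.grad))
```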

To do what you're asking, you can make every parameter belong to its own param group. You'll need some way to keep track of which param group corresponds to which parameter. Once you've defined the optimizer with these groups, you can update the learning rate whenever you want, including during training.
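
One possible way to do the bookkeeping (an assumption on my part, not what the answer below uses) is to store the parameter name directly in each param group dict. The key "name" is our own hypothetical choice; torch.optim fills in missing default keys but leaves extra keys in the group dict alone, so the name travels with the group. The model here is again a hypothetical nn.Sequential.

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical model for illustration.
model = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 1))

# One group per parameter; "name" is our own bookkeeping key, not an optimizer setting.
optimizer = optim.SGD(
    [{"params": [p], "name": n} for n, p in model.named_parameters()],
    lr=0.01,  # every group starts with this default lr
)

# Later (e.g. inside the training loop), look groups up by name.
for group in optimizer.param_groups:
    if group["name"].endswith("bias"):
        group["lr"] = 0.1

print({g["name"]: g["lr"] for g in optimizer.param_groups})
```

This avoids keeping a separate list that must stay in the same order as optimizer.param_groups.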

For example, say we have the following simple linear model

import torch
import torch.nn as nn
import torch.optim as optim


class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20) 
        self.layer2 = nn.Linear(20, 1)

    def forward(self, x): 
        return self.layer2(self.layer1(x))


model = LinearModel()

and suppose we want learning rates for each trainable parameter initialized according to the following:

learning_rates = { 
    'layer1.weight': 0.01,
    'layer1.bias': 0.1,
    'layer2.weight': 0.001,
    'layer2.bias': 1.0}

We can use this dictionary to define a different learning rate for each parameter when we initialize the optimizer.

# Build param_group where each group consists of a single parameter.
# `param_group_names` is created so we can keep track of which param_group
# corresponds to which parameter.
param_groups = []
param_group_names = []
for name, parameter in model.named_parameters():
    param_groups.append({'params': [parameter], 'lr': learning_rates[name]})
    param_group_names.append(name)

# SGD takes a default lr (required in older PyTorch versions) even if every param group overrides it
optimizer = optim.SGD(param_groups, lr=10)

Alternatively, we could omit the 'lr' entry and each param group would be initialized with the default learning rate (lr=10 in this case).
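
A quick sketch of that fallback, using standalone layers as stand-ins for the model above: a group without an 'lr' key gets the default, and the same holds for groups appended later with optimizer.add_param_group.

```python
import torch.nn as nn
import torch.optim as optim

# Stand-ins for layer1 and layer2 of the model above.
layer1 = nn.Linear(10, 20)
layer2 = nn.Linear(20, 1)

optimizer = optim.SGD(
    [
        {"params": [layer1.weight], "lr": 0.01},  # explicit lr
        {"params": [layer1.bias]},                # no 'lr': falls back to lr=10
    ],
    lr=10,
)

# Groups added later also have missing keys filled from the defaults.
optimizer.add_param_group({"params": layer2.parameters()})

for group in optimizer.param_groups:
    print(group["lr"])
```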

At training time, if we want to update the learning rates, we can iterate over optimizer.param_groups and update the 'lr' entry of each group. For example, in the following simplified training loop, we update the learning rates before each step.

for i in range(10):
    output = model(torch.zeros(1, 10))
    loss = output.sum()
    optimizer.zero_grad()
    loss.backward()

    # we can change the learning rate whenever we want for each param group
    print(f'step {i} learning rates')
    for name, param_group in zip(param_group_names, optimizer.param_groups):
        param_group['lr'] = learning_rates[name] / (i + 1)
        print(f'    {name}: {param_group["lr"]}')

    optimizer.step()

which prints

step 0 learning rates
    layer1.weight: 0.01
    layer1.bias: 0.1
    layer2.weight: 0.001
    layer2.bias: 1.0
step 1 learning rates
    layer1.weight: 0.005
    layer1.bias: 0.05
    layer2.weight: 0.0005
    layer2.bias: 0.5
step 2 learning rates
    layer1.weight: 0.0033333333333333335
    layer1.bias: 0.03333333333333333
    layer2.weight: 0.0003333333333333333
    layer2.bias: 0.3333333333333333
step 3 learning rates
    layer1.weight: 0.0025
    layer1.bias: 0.025
    layer2.weight: 0.00025
    layer2.bias: 0.25
step 4 learning rates
    layer1.weight: 0.002
    layer1.bias: 0.02
    layer2.weight: 0.0002
    layer2.bias: 0.2
step 5 learning rates
    layer1.weight: 0.0016666666666666668
    layer1.bias: 0.016666666666666666
    layer2.weight: 0.00016666666666666666
    layer2.bias: 0.16666666666666666
step 6 learning rates
    layer1.weight: 0.0014285714285714286
    layer1.bias: 0.014285714285714287
    layer2.weight: 0.00014285714285714287
    layer2.bias: 0.14285714285714285
step 7 learning rates
    layer1.weight: 0.00125
    layer1.bias: 0.0125
    layer2.weight: 0.000125
    layer2.bias: 0.125
step 8 learning rates
    layer1.weight: 0.0011111111111111111
    layer1.bias: 0.011111111111111112
    layer2.weight: 0.00011111111111111112
    layer2.bias: 0.1111111111111111
step 9 learning rates
    layer1.weight: 0.001
    layer1.bias: 0.01
    layer2.weight: 0.0001
    layer2.bias: 0.1
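
If the per-parameter schedule follows a fixed formula, a built-in scheduler can do the same bookkeeping. This is a swapped-in alternative, not the approach above: torch.optim.lr_scheduler.LambdaLR accepts either one function or a list with one function per param group, and sets each group's lr to its initial lr times that function of the step count. The sketch below reproduces the lr / (i + 1) decay using a hypothetical nn.Sequential stand-in, so the parameter names differ from LinearModel's.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 1))  # stand-in for LinearModel
learning_rates = {"0.weight": 0.01, "0.bias": 0.1, "1.weight": 0.001, "1.bias": 1.0}

optimizer = optim.SGD(
    [{"params": [p], "lr": learning_rates[n]} for n, p in model.named_parameters()],
    lr=10,
)

# One lambda per param group (a single function would also apply to all groups).
# Each scales its group's initial lr by 1 / (step + 1).
scheduler = LambdaLR(
    optimizer, lr_lambda=[lambda step: 1 / (step + 1)] * len(optimizer.param_groups)
)

history = []
for i in range(10):
    history.append([g["lr"] for g in optimizer.param_groups])  # lrs used at step i
    loss = model(torch.zeros(1, 10)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the schedule after the optimizer step
```

Because the scheduler multiplies instead of dividing, the values can differ from the printout above in the last floating-point digit.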

Upvotes: 17
