Mohit Lamba

Reputation: 1373

Setting learning rate for Stochastic Weight Averaging in PyTorch

The following is a small working example of Stochastic Weight Averaging in PyTorch, taken from here.

loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)

for epoch in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    # One scheduler step per epoch: cosine annealing before swa_start, SWALR afterwards
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)

In this code, after the 160th epoch the swa_scheduler is used instead of the usual scheduler. What does swa_lr signify? The documentation says:

Typically, in SWA the learning rate is set to a high constant value. SWALR is a learning rate scheduler that anneals the learning rate to a fixed value, and then keeps it constant.

  1. So what happens to the learning rate of the optimizer after the 160th epoch?
  2. Does swa_lr affect the optimizer's learning rate?

Suppose that at the beginning of the code the optimizer is Adam, initialized with a learning rate of 1e-4. Does the above code then imply that for the first 160 epochs the learning rate for training will be 1e-4, and for the remaining epochs it will be swa_lr=0.05? If yes, is it a good idea to set swa_lr to 1e-4 as well?
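
(For context, this is how I am checking the current learning rate during training; this inspection line is my own addition inside the epoch loop and not part of the quoted example.)

# Hypothetical check inside the epoch loop: both schedulers overwrite the lr stored
# in the optimizer's parameter groups, so this is the value the next optimizer.step() uses.
current_lr = optimizer.param_groups[0]["lr"]
print(f"epoch {epoch}: lr = {current_lr}")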

Upvotes: 5

Views: 3394

Answers (1)

Ivan

Reputation: 40638

  • does the above code imply that for the first 160 epochs the learning rate for training will be 1e-4

    No, it won't be equal to 1e-4. During the first 160 epochs the learning rate is managed by the first scheduler, scheduler, which is initialized as a torch.optim.lr_scheduler.CosineAnnealingLR. The learning rate will follow this curve:

    [plot: cosine-annealing learning-rate curve over the first 160 epochs]


  • for the remaining epochs it will be swa_lr=0.05

    This is partially true. During the second part, from epoch 160 onwards, the optimizer's learning rate is handled by the second scheduler, swa_scheduler, which is initialized as a torch.optim.swa_utils.SWALR. You can read on the documentation page:

    SWALR is a learning rate scheduler that anneals the learning rate to a fixed value [swa_lr], and then keeps it constant.

    By default (cf. the source code, anneal_epochs=10), the annealing phase lasts 10 epochs. Therefore the learning rate from epoch 170 to epoch 300 will be equal to swa_lr and will stay that way. The second part of the profile looks like this:

    [plot: learning rate annealing to swa_lr=0.05 between epochs 160 and 170, then held constant until epoch 300]

    The complete profile, i.e. both parts combined (a short sketch after these bullets reproduces it numerically):

    [plot: full learning-rate profile over all 300 epochs: cosine annealing up to epoch 160, then SWALR annealing to 0.05]


  • If yes, is it a good idea to set swa_lr to 1e-4 as well?

    It is mentioned in the docs:

    Typically, in SWA the learning rate is set to a high constant value.

    Setting swa_lr to 1e-4 would result in the following learning-rate profile:

    [plot: learning-rate profile with swa_lr=1e-4: cosine annealing up to epoch 160, then annealing back to a constant 1e-4]
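
To make the three phases concrete, here is a minimal sketch, using a dummy model and stepping only the schedulers with no actual training, that reproduces the schedule described above; swap swa_lr=0.05 for 1e-4 to get the profile in the last plot:

import torch
from torch.optim.swa_utils import SWALR

# Dummy model and optimizer: only the schedulers are stepped here so that the
# learning-rate profile can be inspected, no training is performed.
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)  # anneal_epochs defaults to 10
swa_start = 160

lrs = []
for epoch in range(300):
    lrs.append(optimizer.param_groups[0]["lr"])  # lr in effect during this epoch
    if epoch > swa_start:
        swa_scheduler.step()   # anneals towards swa_lr, then keeps it constant
    else:
        scheduler.step()       # cosine annealing from 1e-4 towards 0

for e in (0, 80, 160, 165, 171, 299):
    print(f"epoch {e:3d}: lr = {lrs[e]:.6f}")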

Upvotes: 8
