Reputation: 1373
Following is a small working example of Stochastic Weight Averaging in PyTorch, taken from here.
loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)
for epoch in range(300):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()
# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)
In this code, after the 160th epoch the swa_scheduler is used instead of the usual scheduler. What does swa_lr signify? The documentation says,
Typically, in SWA the learning rate is set to a high constant value. SWALR is a learning rate scheduler that anneals the learning rate to a fixed value, and then keeps it constant.
What happens to the optimizer after the 160th epoch? Does swa_lr affect the optimizer's learning rate? Suppose that at the beginning of the code the optimizer was Adam, initialized with a learning rate of 1e-4. Then does the above code imply that for the first 160 epochs the learning rate for training will be 1e-4, and then for the remaining number of epochs it will be swa_lr=0.05? If yes, is it a good idea to define swa_lr also to 1e-4?
Upvotes: 5
Views: 3394
Reputation: 40638
does the above code imply that for the first 160 epochs the learning rate for training will be 1e-4
No, it won't be equal to 1e-4. During the first 160 epochs the learning rate is managed by the first scheduler, scheduler, which is initialized as a torch.optim.lr_scheduler.CosineAnnealingLR. The learning rate therefore decays along a cosine curve, starting from the optimizer's initial value.
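A minimal sketch of just this first part, assuming a toy model and the Adam optimizer with lr=1e-4 from the question (these assumptions are not part of the original answer), printing the value the optimizer actually uses each epoch:
import torch

model = torch.nn.Linear(1, 1)                      # hypothetical stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)

for epoch in range(161):                           # epochs 0..160 are handled by this scheduler
    optimizer.step()                               # dummy step, no real training
    print(epoch, optimizer.param_groups[0]["lr"])  # decays along the cosine curve from 1e-4
    scheduler.step()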
for the remaining number of epochs it will be
swa_lr=0.05
This is partially true. During the second part, from epoch 161 onward (the condition is epoch > swa_start), the optimizer's learning rate is handled by the second scheduler, swa_scheduler, which is initialized as a torch.optim.swa_utils.SWALR. As you can read on the documentation page:
SWALR is a learning rate scheduler that anneals the learning rate to a fixed value [swa_lr], and then keeps it constant.
By default (cf. source code), the annealing phase lasts 10 epochs (anneal_epochs=10). Therefore, from roughly epoch 170 onward the learning rate will be equal to swa_lr and will stay this way until epoch 300. The second part of the schedule is thus a short ramp from the value the cosine schedule reached at the switch up to swa_lr=0.05, followed by a constant plateau; the complete profile, i.e. both parts, is the cosine decay followed by this ramp and plateau.
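To make the complete profile concrete, here is a toy sketch (again assuming the Adam optimizer with lr=1e-4 from the question, and performing no real training) that records the learning rate the optimizer uses during each of the 300 epochs:
import torch

model = torch.nn.Linear(1, 1)                      # hypothetical stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)
swa_start = 160

lrs = []
for epoch in range(300):
    lrs.append(optimizer.param_groups[0]["lr"])    # value used during this epoch
    optimizer.step()                               # dummy step, no gradients
    if epoch > swa_start:
        swa_scheduler.step()                       # anneal towards swa_lr, then hold
    else:
        scheduler.step()                           # cosine decay from 1e-4

# lrs[:162] -> cosine decay, lrs[162:171] -> annealing ramp, lrs[171:] -> constant 0.05
print(lrs[0], lrs[160], lrs[175], lrs[299])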
If yes, is it a good idea to define swa_lr also to 1e-4?
It is mentioned in the docs:
Typically, in SWA the learning rate is set to a high constant value.
Setting swa_lr to 1e-4 would result in a learning-rate profile whose second part anneals from the (lower) value the cosine schedule reached at epoch 160 up to 1e-4 and then stays constant there, rather than at a high value such as 0.05.
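If you want to see that variant numerically, the only change to the toy sketch above is the SWALR target (again an illustrative assumption, not code from the original post):
# Same toy setup as above, with the SWALR target lowered to the assumed Adam base learning rate.
swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=1e-4)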
Upvotes: 8