CasellaJr

Reputation: 458

How to use OneCycleLR?

I want to train on CIFAR-10 for, say, 200 epochs. This is my optimizer:

optimizer = optim.Adam([x for x in model.parameters() if x.requires_grad], lr=0.001)

I want to use OneCycleLR as the scheduler. According to the documentation, these are the parameters of OneCycleLR:

torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, total_steps=None, epochs=None, steps_per_epoch=None, pct_start=0.3, anneal_strategy='cos', cycle_momentum=True, base_momentum=0.85, max_momentum=0.95, div_factor=25.0, final_div_factor=10000.0, three_phase=False, last_epoch=-1, verbose=False)

I have seen that the most used are max_lr, epochs and steps_per_epoch. The documentation says this:

About steps_per_epoch: I have seen many GitHub repos use steps_per_epoch=len(data_loader), so if I have a batch size of 128, then this parameter is equal to 128. However, I do not understand the other two parameters.

If I want to train for 200 epochs, is epochs=200? Or does this parameter make the scheduler run for only that many epochs and then restart? For example, if I write epochs=10 inside the scheduler but train for 200 epochs in total, is that 20 complete cycles of the scheduler?

As for max_lr, I have seen some people use a value greater than the optimizer's lr and others use a smaller one. I think max_lr must be greater than lr (otherwise why is it called max?). However, if I print the learning rate epoch by epoch, it takes strange values. For example, in this setting:

optimizer = optim.Adam([x for x in model.parameters() if x.requires_grad], lr=0.001)

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 0.01, epochs=200, steps_per_epoch=128)

And this is the learning rate:

Epoch 1: TrL=1.7557, TrA=0.3846, VL=1.4136, VA=0.4917, TeL=1.4266, TeA=0.4852, LR=0.0004,
Epoch 2: TrL=1.3414, TrA=0.5123, VL=1.2347, VA=0.5615, TeL=1.2231, TeA=0.5614, LR=0.0004,
...
Epoch 118: TrL=0.0972, TrA=0.9655, VL=0.8445, VA=0.8161, TeL=0.8764, TeA=0.8081, LR=0.0005,
Epoch 119: TrL=0.0939, TrA=0.9677, VL=0.8443, VA=0.8166, TeL=0.9094, TeA=0.8128, LR=0.0005,

So the learning rate is increasing.

Upvotes: 4

Views: 12098

Answers (2)

Jacob Helwig

Reputation: 61

A working version of the code from w568w's answer (the edit queue is full):

import torch
import matplotlib.pyplot as plt
EPOCHS = 10
BATCHES = 10
steps = []
lrs = []
model = torch.nn.Linear(10,10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) # Wrapped optimizer
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,max_lr=0.9,total_steps=EPOCHS * BATCHES)
for epoch in range(EPOCHS):
    for batch in range(BATCHES):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
        steps.append(epoch * BATCHES + batch)

plt.figure()
plt.plot(steps, lrs, label='OneCycle')
plt.legend()  # must come after plot(), otherwise there is no labeled line to show
plt.show()

Upvotes: 0

w568w

Reputation: 173

The documentation says that you should give total_steps or both epochs & steps_per_epoch as arguments. The simple relation between them is total_steps = epochs * steps_per_epoch.
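To make the relation concrete for the setup in the question: steps_per_epoch is the number of batches per epoch (which is what len(data_loader) returns), not the batch size. The sketch below is only arithmetic and assumes the full 50,000-image CIFAR-10 training set with a batch size of 128; if part of the training set is held out for validation, the count is smaller. The helper name batches_per_epoch is made up for illustration.

```python
import math

def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields per epoch (hypothetical helper)."""
    if drop_last:
        # The incomplete final batch is discarded.
        return num_samples // batch_size
    # The incomplete final batch still counts as one step.
    return math.ceil(num_samples / batch_size)

# Assumption: all 50,000 CIFAR-10 training images, batch size 128.
steps_per_epoch = batches_per_epoch(50_000, 128)
epochs = 200
total_steps = epochs * steps_per_epoch

print(steps_per_epoch, total_steps)  # 391 78200
```

Under these assumptions the scheduler would be built with epochs=200 and steps_per_epoch=391 (or, equivalently, total_steps=78200).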

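total_steps is the total number of steps in the cycle. "OneCycle" in the name means there is only one cycle over the whole training run, so the scheduler does not restart. A step here is one batch: the documentation says scheduler.step() should be called after every batch, not once per epoch.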

max_lr is the maximum learning rate of OneCycleLR. To be exact, the learning rate increases from initial_lr = max_lr / div_factor to max_lr over the first pct_start * total_steps steps, and then decreases smoothly to initial_lr / final_div_factor, i.e. max_lr / (div_factor * final_div_factor).
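This also explains the LR=0.0004 printed in the question: the scheduler overrides the lr=0.001 passed to the optimizer and starts at max_lr / div_factor = 0.01 / 25. The sketch below reproduces the shape of the default two-phase cosine schedule in plain Python, using the documented definitions initial_lr = max_lr / div_factor and min_lr = initial_lr / final_div_factor. It is an illustration under those assumptions, not PyTorch's own code: the function name one_cycle_lr is hypothetical and the exact step boundaries in the library may differ slightly.

```python
import math

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3,
                 div_factor=25.0, final_div_factor=1e4):
    """Approximate learning rate at a given step (hypothetical helper)."""
    initial_lr = max_lr / div_factor          # starting learning rate
    min_lr = initial_lr / final_div_factor    # learning rate at the very end
    peak_step = pct_start * total_steps - 1   # step where max_lr is reached
    last_step = total_steps - 1

    def cos_anneal(start, end, pct):
        # Cosine interpolation from start (pct=0) to end (pct=1).
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

    if step <= peak_step:
        # Warm-up phase: initial_lr -> max_lr.
        return cos_anneal(initial_lr, max_lr, step / peak_step)
    # Annealing phase: max_lr -> min_lr.
    pct = (step - peak_step) / (last_step - peak_step)
    return cos_anneal(max_lr, min_lr, pct)

# Values from the question: max_lr=0.01 with the default factors.
for s in (0, 29, 99):
    print(s, one_cycle_lr(s, total_steps=100, max_lr=0.01))
```

With total_steps=100 the three printed steps are the start, the peak and the end of the cycle, so the values should be close to 0.0004, 0.01 and 4e-8 respectively.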


Edit: For those who are not familiar with lr_scheduler, you can plot the learning rate curve, e.g.

import torch
import matplotlib.pyplot as plt

EPOCHS = 10
BATCHES = 10
steps = []
lrs = []
model = ... # Your model instance
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) # Wrapped optimizer
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.9, total_steps=EPOCHS * BATCHES)
for epoch in range(EPOCHS):
    for batch in range(BATCHES):
        optimizer.step()  # step the optimizer first to avoid PyTorch's call-order warning
        scheduler.step()  # one scheduler step per batch
        lrs.append(scheduler.get_last_lr()[0])
        steps.append(epoch * BATCHES + batch)

plt.figure()
plt.plot(steps, lrs, label='OneCycle')
plt.legend()  # must come after plot(), otherwise there is no labeled line to show
plt.show()

Upvotes: 6
