Reputation: 1135
In the paper Attention Is All You Need, under section 5.3, the authors suggest increasing the learning rate linearly for the first warmup_steps training steps and then decreasing it proportionally to the inverse square root of the step number.
How do we implement this in PyTorch with the Adam optimizer, preferably without additional packages?
Upvotes: 24
Views: 48570
Reputation: 6618
Going by the learning-rate formula in section 5.3 of the paper, lrate = d_model^(-0.5) * min(step_num^(-0.5), step_num * warmup_steps^(-1.5)), I can write my formula as
self.model_size ** (-0.5) * min(self.current_step ** (-0.5), self.current_step * self.warmup_steps ** (-1.5))
class AdamWarmup:
    def __init__(self, model_size, warmup_steps, optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0

    def get_lr(self):
        return self.model_size ** (-0.5) * min(self.current_step ** (-0.5), self.current_step * self.warmup_steps ** (-1.5))

    def step(self):
        # Increment the number of steps each time we call the step function
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        # Update the learning rate
        self.lr = lr
        self.optimizer.step()
d_model = 512  # transformer model dim
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
modified_optimizer = AdamWarmup(model_size=d_model, warmup_steps=4000, optimizer=adam_optimizer)
for i, (image, labels) in enumerate(train_loader):
    ...
    modified_optimizer.optimizer.zero_grad()
    loss.backward()
    modified_optimizer.step()
    ...
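A quick sanity check of the schedule shape (linear warmup to a peak at warmup_steps, then inverse-square-root decay), just the same formula evaluated over steps with d_model = 512 and warmup_steps = 4000:
d_model, warmup_steps = 512, 4000
lrs = [d_model ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))
       for step in range(1, 20000)]
print(max(lrs))             # peak LR, roughly 7e-4 for these settings
print(lrs.index(max(lrs)))  # peak is reached at step == warmup_steps (index 3999)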
Upvotes: 0
Reputation: 159
This is built off of Fang Wu's answer but is logarithmic and allows you to use a built-in function as the main scheduler. CosineAnnealingLR can be replaced with any scheduler you want.
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, SequentialLR

train_scheduler = CosineAnnealingLR(optimizer, num_epochs)

def warmup(current_step: int):
    # exponential (base-10) ramp-up over the first number_warmup_epochs epochs
    return 1 / (10 ** (float(number_warmup_epochs - current_step)))

warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup)
scheduler = SequentialLR(optimizer, [warmup_scheduler, train_scheduler], [number_warmup_epochs])
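A minimal usage sketch, assuming optimizer, train_loader, num_epochs, and number_warmup_epochs already exist; SequentialLR switches from the warmup schedule to the cosine schedule on its own once the milestone is reached:
for epoch in range(num_epochs):
    for batch in train_loader:
        ...  # forward pass, loss.backward(), optimizer.step(), optimizer.zero_grad()
    scheduler.step()  # warmup for the first number_warmup_epochs epochs, then cosine annealing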
Upvotes: 3
Reputation: 421
NoamOpt of course provides a path to implement the warmup learning rate, as in https://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer. However, it is a little bit old and inconvenient. A smarter way is to directly use the lambda learning rate scheduler built into PyTorch.
That is, you first define a warmup function that adjusts the learning rate automatically:
def warmup(current_step: int):
    if current_step < args.warmup_steps:  # linear warmup: current_step / warmup_steps * base_lr
        return float(current_step / args.warmup_steps)
    else:  # linear decay: (training_steps - current_step) / (training_steps - warmup_steps) * base_lr
        return max(0.0, float(args.training_steps - current_step) / float(max(1, args.training_steps - args.warmup_steps)))
Then you build the learning rate scheduler and use it during the training process:
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup)
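A minimal sketch of the training loop that goes with it, assuming args.warmup_steps and args.training_steps are set and the scheduler is stepped once per optimizer step (the loss computation is only a placeholder):
for step, batch in enumerate(train_loader):
    loss = model(batch)      # placeholder: compute your training loss here
    loss.backward()
    optimizer.step()
    lr_scheduler.step()      # adjusts the learning rate after every optimizer step
    optimizer.zero_grad()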
Upvotes: 11
Reputation: 121
As suggested in the previous answer, we can use the class introduced at https://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer. But that class will give an error when you try to checkpoint it unless we define methods to save and restore its state_dict.
So here's the full Scheduler:
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.model_size = model_size
        self._rate = 0

    def state_dict(self):
        """Returns the state of the warmup scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the warmup scheduler's state.
        Arguments:
            state_dict (dict): warmup scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return (self.model_size ** (-0.5) *
                min(step ** (-0.5), step * self.warmup ** (-1.5)))
Later, to use it inside the training loop:
optimizer = NoamOpt(input_opts['d_model'], 500,
                    torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
. . .
optimizer.step()
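With the added state_dict/load_state_dict methods, the wrapper can also be checkpointed alongside the model; a minimal sketch (the file name and surrounding variables are only illustrative):
checkpoint = {
    'model': model.state_dict(),
    'adam': optimizer.optimizer.state_dict(),  # the wrapped Adam optimizer
    'scheduler': optimizer.state_dict(),       # NoamOpt state (step count, rate, warmup, model_size)
}
torch.save(checkpoint, 'checkpoint.pt')

# later, to resume training
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model'])
optimizer.optimizer.load_state_dict(checkpoint['adam'])
optimizer.load_state_dict(checkpoint['scheduler'])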
Upvotes: 8
Reputation: 41
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
             min(step ** (-0.5), step * self.warmup ** (-1.5)))

def get_std_opt(model):
    return NoamOpt(model.src_embed[0].d_model, 2, 4000,
                   torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
As in: https://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer
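Compared with the class in the previous answer, the only difference is the extra factor argument, which scales the whole schedule. A quick illustrative check (the dummy parameter below exists only so the optimizers can be constructed):
import torch

dummy = [torch.nn.Parameter(torch.zeros(1))]  # placeholder params, only needed to build the optimizers
base = NoamOpt(512, 1, 4000, torch.optim.Adam(dummy, lr=0))
doubled = NoamOpt(512, 2, 4000, torch.optim.Adam(dummy, lr=0))
print(base.rate(step=4000), doubled.rate(step=4000))  # ~7.0e-4 vs ~1.4e-3 at the warmup peak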
Upvotes: 4
Reputation: 11628
PyTorch provides learning rate schedulers for implementing various methods of adjusting the learning rate during the training process. Some simple LR schedulers are already implemented and can be found here: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
In your special case you can, just like the other LR schedulers do, subclass _LRScheduler to implement a variable schedule based on the number of epochs. For a bare-bones version you only need to implement the __init__() and get_lr() methods.
Just note that many of these schedulers expect you to call .step() once per epoch. But you can also update them more frequently, or even pass a custom argument as the cosine-annealing LR scheduler does: https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#CosineAnnealingLR
Upvotes: 16