Reputation: 22694
I’m trying to implement Adam by myself for a learning purpose.
Here is my Adam implementation:
class ADAMOptimizer(Optimizer):
"""
implements ADAM Algorithm, as a preceding step.
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(ADAMOptimizer, self).__init__(params, defaults)
def step(self):
"""
Performs a single optimization step.
"""
loss = None
for group in self.param_groups:
#print(group.keys())
#print (self.param_groups[0]['params'][0].size()), First param (W) size: torch.Size([10, 784])
#print (self.param_groups[0]['params'][1].size()), Second param(b) size: torch.Size([10])
for p in group['params']:
grad = p.grad.data
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Momentum (Exponential MA of gradients)
state['exp_avg'] = torch.zeros_like(p.data)
#print(p.data.size())
# RMS Prop componenet. (Exponential MA of squared gradients). Denominator.
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
b1, b2 = group['betas']
state['step'] += 1
# L2 penalty. Gotta add to Gradient as well.
if group['weight_decay'] != 0:
grad = grad.add(group['weight_decay'], p.data)
# Momentum
exp_avg = torch.mul(exp_avg, b1) + (1 - b1)*grad
# RMS
exp_avg_sq = torch.mul(exp_avg_sq, b2) + (1-b2)*(grad*grad)
denom = exp_avg_sq.sqrt() + group['eps']
bias_correction1 = 1 / (1 - b1 ** state['step'])
bias_correction2 = 1 / (1 - b2 ** state['step'])
adapted_learning_rate = group['lr'] * bias_correction1 / math.sqrt(bias_correction2)
p.data = p.data - adapted_learning_rate * exp_avg / denom
if state['step'] % 10000 ==0:
print ("group:", group)
print("p: ",p)
print("p.data: ", p.data) # W = p.data
return loss
I think I implemented everything correct however the loss graph of my implementation is very spiky compared to that of torch.optim.Adam.
My ADAM implementation loss graph (below)
torch.optim.Adam loss graph (below)
If someone could tell me what I am doing wrong, I’ll be very grateful.
For the full code including data, graph (super easy to run): https://github.com/aerinkim/AMS_pytorch/blob/master/AdamFails_1dConvex.ipynb
Upvotes: 2
Views: 3156
Reputation: 1
It looks as if the above code is not storing the updated state for exp_avg
and exp_avg_sq
. Another minor detail is that the bias correction for the denominator is also applied to the epsilon. Finally, Adam uses beta2=0.999
by default.
The following version with minimal changes should do the trick:
class ADAMOptimizer(torch.optim.Optimizer):
"""
implements ADAM Algorithm, as a preceding step.
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(ADAMOptimizer, self).__init__(params, defaults)
def step(self):
"""
Perform a single optimization step.
"""
loss = None
for group in self.param_groups:
for p in group['params']:
grad = p.grad.data
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Momentum (Exponential MA of gradients)
state['exp_avg'] = torch.zeros_like(p.data)
# RMS Prop componenet. (Exponential MA of squared gradients). Denominator.
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
b1, b2 = group['betas']
state['step'] += 1
# Add weight decay if any
if group['weight_decay'] != 0:
grad = grad.add(group['weight_decay'], p.data)
# Momentum
exp_avg = torch.mul(exp_avg, b1) + (1 - b1)*grad
# RMS
exp_avg_sq = torch.mul(exp_avg_sq, b2) + (1-b2)*(grad*grad)
mhat = exp_avg / (1 - b1 ** state['step'])
vhat = exp_avg_sq / (1 - b2 ** state['step'])
denom = torch.sqrt( vhat + group['eps'] )
p.data = p.data - group['lr'] * mhat / denom
# Save state
state['exp_avg'], state['exp_avg_sq'] = exp_avg, exp_avg_sq
return loss
Upvotes: 0