Stepan Yakovenko

Reputation: 9206

Why isn't PyTorch minimizing x*x for me?

I expect x to converge to 0, which is the minimum of x*x, but this doesn't happen. What am I doing wrong in this small sample code:

import torch
from torch.autograd import Variable
tns = torch.FloatTensor([3])
x = Variable(tns, requires_grad=True)
z = x*x
opt = torch.optim.Adam([x], lr=.01, betas=(0.5, 0.999))
for i in range(3000):
    z.backward(retain_graph=True) # Calculate gradients
    opt.step()
    print(x)

Upvotes: 3

Views: 1781

Answers (1)

enumaris

Reputation: 1938

The problem is that you don't zero the gradients on each pass through the loop. By setting retain_graph=True and never calling opt.zero_grad(), you add each newly computed gradient to ALL previously computed gradients. So instead of taking a step of gradient descent, you are taking a step with respect to all the accumulated gradients, which is certainly NOT what you want.

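As a quick illustration, here is a minimal sketch (using the newer torch.tensor API rather than Variable) of how x.grad keeps growing when backward() is called repeatedly without zeroing:

import torch

x = torch.tensor([3.0], requires_grad=True)  # same starting point as above
for i in range(3):
    z = x * x
    z.backward()      # d(x*x)/dx = 2*x = 6 at x = 3
    print(x.grad)     # prints 6., then 12., then 18. -- gradients accumulate
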
You should instead call opt.zero_grad() at the beginning of your loop, and move z = x*x inside the loop so that you don't need retain_graph=True.

I made these slight modifications:

import torch
from torch.autograd import Variable
tns = torch.FloatTensor([3])
x = Variable(tns, requires_grad=True)
opt = torch.optim.Adam([x], lr=.01, betas=(0.5, 0.999))
for i in range(3000):
    opt.zero_grad()  # clear gradients accumulated in the previous iteration
    z = x*x          # rebuild the graph each iteration, so retain_graph is not needed
    z.backward()     # calculate gradients
    opt.step()
    print(x)

And my final x is 1e-25.

Upvotes: 5
