Reputation: 12669
I am trying to understand the grad() function in PyTorch. I know about backpropagation, but I have some doubts about the result that .grad() gives me.
So if I have a very simple network, say with a single input and a single weight:
import torch
from torch.autograd import Variable
from torch import FloatTensor

a_tensor = Variable(FloatTensor([1]))
weight = Variable(FloatTensor([1]), requires_grad=True)
Now I am running this in an IPython cell:
net_out = a_tensor * weight
loss = 5 - net_out
loss.backward()
print("atensor", a_tensor)
print('weight', weight)
print('net_out', net_out)
print('loss', loss)
print(weight.grad)
During the first run it returns:
atensor tensor([ 1.])
weight tensor([ 1.])
net_out tensor([ 1.])
loss tensor([ 4.])
tensor([-1.])
Which is correct, because if I am right, the gradient computation here would be:
d(net_out)/d(w) = d(w * a)/d(w) ==> 1 * a
d(loss)/d(net_out) = d(5 - net_out)/d(net_out) ==> (0 - 1)
So by the chain rule, d(loss)/d(w) = 1 * a * -1 ==> -1
But the problem is that if I run the same cell again without modifying anything, then I get grad -2, -3, -4, etc.:
atensor tensor([ 1.])
weight tensor([ 1.])
net_out tensor([ 1.])
loss tensor([ 4.])
tensor([-2.])
next run:
atensor tensor([ 1.])
weight tensor([ 1.])
net_out tensor([ 1.])
loss tensor([ 4.])
tensor([-3.])
and so on.
I don't understand what is happening there: why and how is the value of grad increasing?
Upvotes: 1
Views: 3490
Reputation: 1938
This is because you are not zeroing the gradients. What loss.backward() does is accumulate gradients: it adds the newly computed gradients to the existing ones. If you don't zero the gradients, then running loss.backward() over and over just keeps adding the gradients to each other. What you want to do is zero the gradients after each step, and you will see that the gradients are calculated correctly.
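For example, a minimal sketch of what is going on (using plain tensors instead of Variable, just to keep it short) would look like this:
import torch

a = torch.tensor([1.0])
w = torch.tensor([1.0], requires_grad=True)

for i in range(3):
    net_out = a * w       # re-running your cell rebuilds the graph like this
    loss = 5 - net_out
    loss.backward()       # adds d(loss)/d(w) = -1 to w.grad each time
    print(w.grad)         # tensor([-1.]), tensor([-2.]), tensor([-3.])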
If you have built a network net (which should be an nn.Module class object), you can zero the gradients simply by calling net.zero_grad().
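For instance, a rough sketch with a toy one-weight model (just illustrative, not your actual network):
import torch
import torch.nn as nn

net = nn.Linear(1, 1, bias=False)    # toy single-weight model
x = torch.tensor([[1.0]])

for i in range(3):
    loss = (5 - net(x)).sum()
    loss.backward()
    print(net.weight.grad)           # stays the same on every iteration
    net.zero_grad()                  # clears all parameter gradients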
If you haven't built a net (or a torch.optim object), you will have to zero the gradients yourself manually.
Use the weight.grad.data.zero_() method there.
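Applied to your example, that would look roughly like this:
net_out = a_tensor * weight
loss = 5 - net_out
loss.backward()
print(weight.grad)           # tensor([-1.]) on every run now
weight.grad.data.zero_()     # reset the gradient before the next run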
Upvotes: 2