Reputation: 11
In PyTorch, backward() accumulates gradients, so we have to reset them every mini-batch by calling optimizer.zero_grad(). In that case, how does SGD with momentum work, given that momentum SGD updates the weights using an exponential average of the gradients from past mini-batches?
As a beginner in PyTorch, I am confused. Doesn't it need the past gradients to perform the update?
Upvotes: 1
Views: 4863
Reputation: 114806
When using momentum, the optimizer needs to store a one-element history (the momentum buffer) for each parameter; other solvers (e.g. Adam) require even more. The optimizer knows how to store this history data and accumulate new gradients in an orderly fashion, so you do not have to worry about it. Note that zero_grad() only clears the .grad fields of the parameters; it does not touch the optimizer's internal state.
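You can see that the momentum history survives zero_grad() by inspecting the optimizer's state. A minimal sketch with a toy parameter and loss (not code from the original answer):

```python
import torch

# One scalar parameter trained with SGD + momentum.
w = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.SGD([w], lr=0.1, momentum=0.9)

for step in range(3):
    opt.zero_grad()            # clears w.grad only
    loss = (w * 2.0).sum()     # toy loss, d(loss)/dw = 2
    loss.backward()
    opt.step()                 # updates w and the momentum buffer

# The momentum history lives in the optimizer's state, untouched by zero_grad():
print(opt.state[w]['momentum_buffer'])
```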
So why zero_grad(), you probably ask yourself?
Well, sometimes an entire minibatch does not fit into GPU memory and you want to split its processing into several "mini"-minibatches, without decreasing the effective batch size used for computing the gradients and weight updates.
In that case, you call zero_grad() once, run forward and backward for all the mini-minibatches, and only then call optimizer.step(). The step uses the gradients accumulated over all the mini-minibatches, so you get an effective update as if you had run the whole minibatch at once (if your loss uses mean reduction, scale each mini-minibatch loss by the number of chunks so the accumulated gradient matches the full-batch mean). A sketch of this pattern is shown below. See this thread for more details.
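A minimal sketch of the accumulation pattern; the model, loss, and chunk count here are illustrative, not from the original answer:

```python
import torch

# Assumed setup: a small linear classifier and a full minibatch of 32 samples
# that we pretend does not fit into GPU memory, so we split it into 4 chunks.
model = torch.nn.Linear(10, 2)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

full_x = torch.randn(32, 10)
full_y = torch.randint(0, 2, (32,))
chunks = zip(full_x.chunk(4), full_y.chunk(4))   # 4 "mini"-minibatches

optimizer.zero_grad()                   # called once per effective minibatch
for x, y in chunks:
    loss = criterion(model(x), y) / 4   # scale so summed gradients match the full-batch mean
    loss.backward()                     # gradients accumulate in .grad
optimizer.step()                        # one update using the accumulated gradients
```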
Some more information about gradients and optimizers in PyTorch can be found here and here.
Upvotes: 2