I need to backpropagate through my neural network multiple times, so I call backward(retain_graph=True).
However, this causes
RuntimeError: CUDA out of memory
I don't understand why.
Is the number of variables or weights doubling? Shouldn't the amount of memory used stay the same regardless of how many times backward() is called?
The source of the issue:
You are right that, in theory, memory should not increase no matter how many times you call the backward function.
However, your issue is caused not by the backpropagation itself but by the retain_graph argument that you set to True when calling the backward function.
When you run your network on a batch of input data, you call the forward function, which creates a "computation graph". The computation graph records all the operations your network performed, together with the intermediate tensors (activations) that are needed later to compute the gradients.
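To make the idea of a recorded graph concrete, here is a minimal sketch on a toy tensor (the values and the helper graph_node_names are hypothetical, for illustration only). It follows the grad_fn attribute that PyTorch attaches to results of tracked operations and walks back through next_functions. It should print the node types, something like SumBackward0, MulBackward0, AccumulateGrad; exact names may vary between PyTorch versions.

```python
import torch

# Leaf tensor; requires_grad=True tells autograd to record operations on it.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Forward pass: each operation adds a node to the computation graph.
y = x * 2        # recorded as a multiplication node
loss = y.sum()   # recorded as a sum node

def graph_node_names(tensor):
    """Hypothetical helper: follow the first input of each node back to the leaf."""
    names = []
    node = tensor.grad_fn
    while node is not None:
        names.append(type(node).__name__)
        # next_functions is a tuple of (node, input_index) pairs; leaf nodes have none
        node = node.next_functions[0][0] if node.next_functions else None
    return names

names = graph_node_names(loss)
print(names)
```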
Then, when you call the backward function, the saved computation graph is "basically" run in reverse to find out in which direction each weight should be adjusted (these are the gradients). This is why PyTorch keeps the computation graph in memory until backward is called.
After the backward function has been called and the gradients have been computed, PyTorch frees the graph from memory, as explained in the docs (https://pytorch.org/docs/stable/autograd.html):
retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.
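The effect of that default can be seen on a toy example (hypothetical values; x ** 2 is used because its backward needs the saved input x). Without retain_graph, a second backward through the same graph raises a RuntimeError because the saved tensors are gone; with retain_graph=True it succeeds and the gradients add up.

```python
import torch

# Toy example (hypothetical values): x ** 2 saves x for the backward pass.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (x ** 2).sum()

loss.backward()  # retain_graph defaults to False: saved tensors are freed afterwards
print(x.grad)    # gradient 2 * x

second_backward_failed = False
try:
    loss.backward()  # the graph's saved tensors are gone, so this raises
except RuntimeError as err:
    second_backward_failed = True
    print("second backward failed:", err)

# Same computation, but keeping the graph alive for a second pass.
y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss_y = (y ** 2).sum()
loss_y.backward(retain_graph=True)  # saved tensors are kept
loss_y.backward()                   # works; 2 * y is accumulated into y.grad twice
print(y.grad)
```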
Usually, during training, we then apply the gradients to the network to minimise the loss, re-run the network, and thereby create a new computation graph. So only one graph is held in memory at a time.
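A minimal sketch of that usual loop, assuming a hypothetical toy linear model on random data (not the asker's network). Each iteration builds one fresh graph, backward frees it, and only a plain Python float of the loss is kept, so nothing holds on to old graphs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: a linear model fit to random data, for illustration only.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
inputs = torch.randn(32, 4)
targets = torch.randn(32, 1)

losses = []
for step in range(20):
    optimizer.zero_grad()                   # clear gradients from the previous step
    loss = loss_fn(model(inputs), targets)  # forward pass builds a fresh graph
    loss.backward()                         # computes gradients, then frees this graph
    optimizer.step()                        # applies the gradients to the weights
    losses.append(loss.item())              # store a float, not the tensor (and its graph)

print(losses[0], losses[-1])
```

Appending loss itself instead of loss.item() would keep every iteration's graph reachable, which is one common way graphs end up piling up.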
The issue:
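If you set retain_graph to True when you call the backward function, the computation graph and the intermediate tensors it saved are not freed after backward. If the next run of your network is still connected to that graph (for example through a hidden state or a running total of the loss that is never detached), you end up keeping the computation graphs of ALL the previous runs of your network in memory.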
Since every run of your network creates a new computation graph, storing them all will eventually exhaust your memory.
On the first iteration you have one graph in memory; on the 10th, ten graphs; on the 10,000th, 10,000 graphs. This is not sustainable, which is why the docs discourage it.
So although the issue may seem to come from backpropagation, it actually comes from storing the computation graphs. Since we usually call the forward and backward functions once per iteration, the confusion is understandable.
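A minimal sketch of how the graphs pile up, assuming a hypothetical toy recurrent update h = tanh(W h + noise) whose state is carried from one iteration to the next. The helper count_graph_nodes (also hypothetical) counts the autograd nodes reachable from the loss. Without detaching h, each iteration's graph links to all earlier ones and the count keeps growing (and backward has to walk all of it); with h.detach() the count stays constant and retain_graph is no longer needed.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy recurrent update: h_next = tanh(W @ h + noise)
W = torch.randn(3, 3, requires_grad=True)

def count_graph_nodes(tensor):
    """Count the distinct autograd nodes reachable from tensor.grad_fn."""
    seen, stack = set(), [tensor.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        stack.extend(next_node for next_node, _ in node.next_functions)
    return len(seen)

def run(steps, detach_state):
    h = torch.zeros(3)
    sizes = []
    for _ in range(steps):
        if detach_state:
            h = h.detach()  # cut the link to the previous iteration's graph
        h = torch.tanh(W @ h + torch.randn(3))
        loss = h.sum()
        sizes.append(count_graph_nodes(loss))
        # Without detaching, the next backward also traverses this graph,
        # so it must be retained; with detaching, it can be freed as usual.
        loss.backward(retain_graph=not detach_state)
    return sizes

grow = run(5, detach_state=False)
flat = run(5, detach_state=True)
print(grow)  # node count grows every iteration
print(flat)  # node count stays constant
```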
Solution:
What you need to do is find a way to make your network and architecture work without retain_graph. In this situation, using it makes training almost impossible: each iteration increases memory usage and slows training down (backward may also have to traverse the retained graphs of earlier iterations), and in your case it even runs you out of memory.
You did not mention why you need to backpropagate multiple times, but it is rarely needed, and I do not know of a case where it cannot be "worked around". For example, if you need to access variables or weights from previous runs, you could save them (detached from the graph, e.g. with .detach() or .item()) and access them later, instead of doing a new backpropagation.
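One frequent reason for calling backward twice is having two losses computed from the same forward pass. Assuming that is the situation (the toy function two_losses below is hypothetical), summing the losses and calling backward once gives the same gradients as two backward calls with retain_graph=True, without keeping the graph alive.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
x = torch.randn(3)

def two_losses(w):
    """Hypothetical example: two losses sharing part of the graph."""
    shared = torch.tanh(w * x)
    return shared.sum(), (shared ** 2).mean()

# Variant A: two backward calls; the first one needs retain_graph=True.
loss_a, loss_b = two_losses(w)
loss_a.backward(retain_graph=True)
loss_b.backward()
grad_two_calls = w.grad.clone()

# Variant B: sum the losses and backpropagate once; no retained graph required.
w.grad = None
loss_a, loss_b = two_losses(w)
(loss_a + loss_b).backward()
grad_one_call = w.grad.clone()

print(torch.allclose(grad_two_calls, grad_one_call))
```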
You may need to backpropagate multiple times for another reason, but believe me, as I have been in this situation: there is likely a way to accomplish what you are trying to do without storing the previous computation graphs.
If you share why you need to backpropagate multiple times, others and I may be able to help you more.
More about the backward process:
If you want to learn more about the backward process, what it computes is called the "vector-Jacobian product". It is a bit complex and is handled by PyTorch. I do not fully understand it yet, but this resource seems like a good starting point, as it is less intimidating (in terms of algebra) than the PyTorch documentation: https://mc.ai/how-pytorch-backward-function-works/
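As a small illustration of the term, here is a sketch on a toy elementwise function (chosen for this example, not taken from the linked resource). For a non-scalar output y, y.backward(gradient=v) puts v^T J into x.grad, where J is the Jacobian of y with respect to x; for a scalar loss, v is implicitly 1. The sketch compares that against an explicitly built Jacobian.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2  # elementwise, so the Jacobian J is diag(2 * x)

v = torch.tensor([1.0, 0.5, 0.0])
y.backward(gradient=v)  # autograd computes v^T J without building J

# The same product computed by hand from the explicit Jacobian.
J = torch.diag(2 * x.detach())
manual = v @ J

print(x.grad, manual)
```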