Reputation: 1565
I'm currently writing a U-Net-based segmentation model in PyTorch, and I want to use something similar to the inverted residual block introduced in MobileNetV2 to improve the model's speed on CPU (following the official PyTorch code for MobileNetV2).
Then I realized that the model uses a lot more memory in both the training and test phases. The model should use more memory during training, because all the intermediate tensors (feature maps) are kept for backpropagation, and with separable convolutions more tensors are created for each "convolution" operation. But at test time, only the last few tensors need to be kept for the skip connections; every other tensor can be freed as soon as the next step has been computed. So memory efficiency at test time should be the same for a U-Net with normal convolutions and a U-Net with separable convolutions.
I'm new to PyTorch, so I don't know how to write code that avoids unnecessary memory use at test time. Since PyTorch is bound to Python, I guess I could manually delete all the unnecessary tensors in the forward function with del, but I suspect that deleting variables in forward would interfere with the training stage. Is there more advanced functionality in PyTorch that can optimize test-phase memory usage using the network graph? I'm also curious whether TensorFlow handles this automatically, since it has a more abstract and complex graph-building logic.
Upvotes: 0
Views: 1636
Reputation: 1565
After reading the official PyTorch code for ResNet, I realized I shouldn't give every intermediate tensor its own name, i.e. I shouldn't write:
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
I should just write:
out = self.conv1(x)
out = self.conv2(out)
This way, nothing refers to the object bound to conv1 after it has been used, so Python is able to garbage-collect it.
Because there are residual connections between blocks, I need one extra Python variable to keep a reference to that tensor alive:
out = self.conv1(x)
residual_connect = out
out = self.conv2(out)
out = residual_connect + out
But in the upsampling stage only out is needed, so I delete residual_connect at the beginning of the decoding stage:
del residual_connect
It feels like a hack, and I'm surprised it didn't cause problems during training; apparently del only drops the Python name, while autograd keeps its own reference to any tensor the backward graph still needs. The RAM usage of my model is greatly reduced now, but I feel there should be a more elegant way to solve this.
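The pattern above can be sketched as a small self-contained module (hypothetical layer names and sizes, not the actual model): reuse out so old tensors lose their last reference, hold one extra reference for the skip connection, and del it once the decoder has consumed it.

```python
import torch
import torch.nn as nn

class TinyUNetBlock(nn.Module):
    """Hypothetical encoder/decoder pair illustrating the reference-dropping pattern."""
    def __init__(self, channels=8):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.up = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        out = self.conv1(x)           # `out` is reused, so earlier tensors lose their last name
        residual_connect = out        # keep one reference alive for the skip connection
        out = self.conv2(out)
        out = residual_connect + out  # skip connection
        del residual_connect          # decoder no longer needs it; drop the name eagerly
        out = self.up(out)
        return out

x = torch.randn(1, 8, 16, 16)
with torch.no_grad():                 # no autograd graph is built at test time
    y = TinyUNetBlock()(x)
print(y.shape)  # torch.Size([1, 8, 16, 16])
```

At test time (under no_grad) the del lets the intermediate be freed immediately; during training, autograd's graph still holds the tensor for backward, which is why deleting the Python name is harmless.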
Upvotes: 2
Reputation: 37741
Since you have used torch.no_grad() during testing, you are asking the context manager to disable gradient calculation, which results in less memory usage than during training. However, I found that the caching allocator holds on to a lot of memory that can be released after each update to your model during training, which saves a lot of memory.
So you can use the function torch.cuda.empty_cache(), and I found it really helpful in my case. Also, reading the memory management documentation can help you learn other important things about GPU memory management in PyTorch.
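A minimal sketch of the two points above, assuming a stand-in model (the empty_cache() call only has an effect when a CUDA device is present; no_grad() helps on CPU as well):

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 8, 3, padding=1)  # stand-in for the real segmentation model
model.eval()

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():                  # disables graph construction -> less memory
    y = model(x)
assert not y.requires_grad             # no autograd history is attached to the output

if torch.cuda.is_available():
    torch.cuda.empty_cache()           # release cached blocks back to the driver
```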
Upvotes: 0