Tengerye

Reputation: 1964

Memory leak in Pytorch: object detection

I am working through the PyTorch object detection tutorial. The original tutorial runs fine with the few epochs it specifies, but when I increased the number of epochs I ran into an out-of-memory error.

I tried to debug it and found something interesting. This is the tool I am using:

import gc

import torch

def debug_gpu():
    # Debug out-of-memory bugs: count every tensor (and every parameter,
    # via its .data attribute) still reachable by the garbage collector.
    tensor_list = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                tensor_list.append(obj)
        except Exception:
            # Some objects raise when inspected; skip them.
            pass
    print(f'Count of tensors = {len(tensor_list)}.')

And I used it to monitor the memory of training one epoch:

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    ...
    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        # inference + backward + optimization
        debug_gpu()

The output is something like this:

Count of tensors = 414.
Count of tensors = 419.
Count of tensors = 424.
Count of tensors = 429.
Count of tensors = 434.
Count of tensors = 439.
Count of tensors = 439.
Count of tensors = 444.
Count of tensors = 449.
Count of tensors = 449.
Count of tensors = 454.

As you can see, the count of tensors tracked by the garbage collector increases steadily.
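For reference, here is a minimal sketch of how allocated GPU memory could be logged alongside the tensor count. The helper debug_gpu_memory and where it is called are my own additions; torch.cuda.memory_allocated and torch.cuda.memory_reserved are standard PyTorch APIs.

import torch

def debug_gpu_memory():
    # Hypothetical helper: report how much memory is currently allocated
    # (and reserved by the caching allocator) on the default CUDA device,
    # to cross-check the growing tensor count above.
    if torch.cuda.is_available():
        allocated_mb = torch.cuda.memory_allocated() / 1024 ** 2
        reserved_mb = torch.cuda.memory_reserved() / 1024 ** 2
        print(f'Allocated: {allocated_mb:.1f} MiB, reserved: {reserved_mb:.1f} MiB')

Calling it right after debug_gpu() in the loop shows whether the allocated memory actually grows along with the tensor count.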

Relevant files to execute can be found here.

I have two questions: 1. What is preventing the garbage collector from releasing these tensors? 2. What should I do about the out-of-memory error?

Upvotes: 2

Views: 1240

Answers (1)

Tengerye

Reputation: 1964

  1. How did I identify the error? With the help of tracemalloc, I took two snapshots several hundred iterations apart and compared them (a sketch of that comparison is shown after this list). The tracemalloc tutorial makes this easy to follow.

  2. What causes the error? rpn.anchor_generator._cache in torchvision's detection model is a Python dict that caches the grid anchors. It is an attribute of the detection model, and its size grows with each proposal.

  3. How to solve it? An easy workaround is to call model.rpn.anchor_generator._cache.clear() at the end of each training iteration (see the second sketch after this list).
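A minimal sketch of the tracemalloc comparison from point 1; the placement of the snapshots around the training iterations and the number of entries printed are illustrative, not the exact code I ran:

import tracemalloc

tracemalloc.start()
snapshot_before = tracemalloc.take_snapshot()

# ... run several hundred training iterations here ...

snapshot_after = tracemalloc.take_snapshot()
# Show which source lines allocated the most new memory between the snapshots.
for stat in snapshot_after.compare_to(snapshot_before, 'lineno')[:10]:
    print(stat)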

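And a minimal sketch of the workaround from point 3, assuming a torchvision detection model such as fasterrcnn_resnet50_fpn and a data_loader as in the tutorial; the loop body is a placeholder:

import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

for images, targets in data_loader:  # data_loader as in the tutorial
    # ... forward pass, compute losses, backward, optimizer step ...

    # Clear the cached grid anchors so the dict cannot grow without bound.
    model.rpn.anchor_generator._cache.clear()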

I have submitted a fix to PyTorch, so you should no longer hit this OOM error from torchvision 0.5 onward.

Upvotes: 1
