Roi

Reputation: 121

torch Conv1d GPU memory spikes during backward prop. on smaller inputs

(Short reproducible code below.)

I'm seeing very bizarre behavior with torch.nn.Conv1d: when I feed it a smaller input (below some threshold), GPU memory usage spikes sharply during the backward pass, by an order of magnitude or more.

My hypothesis is that torch/cuDNN selects different convolution algorithms depending on the input dimensions and available memory; the problem is that this causes very undesirable OOM errors at runtime.
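One way to probe the algorithm-selection hypothesis is to restrict which algorithms cuDNN may pick. These are real PyTorch flags, but whether they avoid the spike in this particular case is an assumption, not a confirmed fix:

```python
import torch

# Sketch: flags that influence cuDNN's convolution-algorithm selection.
# benchmark=True lets cuDNN autotune an algorithm per input shape;
# deterministic=True restricts it to deterministic algorithms. Either
# can change how much workspace memory the backward pass allocates.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
```

Restricting the choice trades speed for predictability: the autotuned algorithm is often faster but may request a much larger workspace for some shapes.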

import torch
from torch import nn
%load_ext pytorch_memlab

base_ch = 512
d_pos = 64


def print_gpu_mem_usage(prefix=""):
    print(f"{prefix}Peak memory: {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB"
          f" | {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB"
          f" (Current: {torch.cuda.memory_allocated() / 1024 ** 3:.2f} GB"
          f" | {torch.cuda.memory_reserved() / 1024 ** 3:.2f} GB)")


def isolated_conv(v):
    samp_conv = nn.Conv1d(base_ch + d_pos, 2 * base_ch, kernel_size=1, padding='valid').cuda()
    mn = samp_conv(v).mean()
    mn.backward()

%mlrun -f isolated_conv isolated_conv(torch.rand(5000, base_ch+d_pos, 11).cuda())

active_bytes  reserved_bytes  line  code
 (all peak)     (all peak)
  108.00M        108.00M        6   def isolated_conv(v):
  328.00M        346.00M        7   mn = nn.Conv1d(.....)
  542.00M        562.00M        8   mn.backward()

However, switch the number of samples from 5000 to 4000 and it explodes:

%mlrun -f isolated_conv isolated_conv(torch.rand(4000, base_ch+d_pos, 11).cuda())

active_bytes  reserved_bytes  line  code
 (all peak)     (all peak)
   86.00M         86.00M        6   def isolated_conv(v):
  260.00M        280.00M        7   mn = nn.Conv1d(.....)
    8.07G          8.25G        8   mn.backward()

The same happens if I run the two in the opposite order.

This runs in a Docker container, so if you can't reproduce it with the following versions I can share a Dockerfile.

Edit: I have since read about torch.backends.cudnn.deterministic = True, which slightly changes this behavior, but the problem still occurs when I go down to 3000 samples.
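If no backend flag avoids the spike, a workaround sketch is to split the batch into chunks and accumulate gradients, keeping every backward pass at a size that stays on the cheap code path. The chunk size and the helper name are assumptions for illustration; because `.mean()` is linear, weighting each chunk's mean by its share of the batch reproduces the full-batch gradient:

```python
import torch
from torch import nn

def chunked_mean_backward(conv, v, chunk_size=5000):
    """Run conv(v).mean().backward() in batch chunks, accumulating
    gradients so the result matches the full-batch computation."""
    n = v.shape[0]
    total = 0.0
    for i in range(0, n, chunk_size):
        part = v[i:i + chunk_size]
        w = part.shape[0] / n            # this chunk's share of the batch
        mn = conv(part).mean()
        (mn * w).backward()              # grads accumulate in conv's params
        total += mn.item() * w
    return total
```

This trades a little speed (several small kernel launches) for a bounded per-step workspace, since each backward sees at most `chunk_size` samples.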

Upvotes: 0

Views: 94

Answers (0)
