(Short reproducible code below.)
I'm seeing very bizarre behavior when using torch.nn.Conv1d: when feeding a smaller input (below some threshold), GPU memory usage spikes sharply during backward - by an order of magnitude or more.
We hypothesize that this has to do with torch/CUDA choosing different convolution algorithms depending on the dimensions and the available memory; the problem is that this results in very undesirable OOM errors at runtime.
import torch
from torch import nn

%load_ext pytorch_memlab

base_ch = 512
d_pos = 64

def print_gpu_mem_usage(prefix=""):
    print(f"{prefix}Peak memory: {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB"
          f" | {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB"
          f" (Current: {torch.cuda.memory_allocated() / 1024 ** 3:.2f} GB"
          f" | {torch.cuda.memory_reserved() / 1024 ** 3:.2f} GB)")

def isolated_conv(v):
    samp_conv = nn.Conv1d(base_ch + d_pos, 2 * base_ch, kernel_size=1, padding='valid').cuda()
    mn = samp_conv(v).mean()
    mn.backward()

%mlrun -f isolated_conv isolated_conv(torch.rand(5000, base_ch + d_pos, 11).cuda())
| active_bytes (all, peak) | reserved_bytes (all, peak) | line | code |
|---|---|---|---|
| 108.00M | 108.00M | 6 | def isolated_conv(v): |
| 328.00M | 346.00M | 7 | mn = nn.Conv1d(.....) |
| 542.00M | 562.00M | 8 | mn.backward() |
However, switching the number of samples from 5000 to 4000 makes it explode:
%mlrun -f isolated_conv isolated_conv(torch.rand(4000, base_ch+d_pos, 11).cuda())
| active_bytes (all, peak) | reserved_bytes (all, peak) | line | code |
|---|---|---|---|
| 86.00M | 86.00M | 6 | def isolated_conv(v): |
| 260.00M | 280.00M | 7 | mn = nn.Conv1d(.....) |
| 8.07G | 8.25G | 8 | mn.backward() |
The same happens if I run the two in the opposite order.
This runs inside a Docker container, so if you can't reproduce it with the following versions, I can share a Dockerfile.
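For anyone trying to reproduce without pytorch_memlab, a minimal sketch of the same measurement using only torch.cuda's memory stats (the helper name and the default channel sizes, which match base_ch + d_pos = 576 and 2 * base_ch = 1024, are my own):

```python
import torch
from torch import nn

def peak_backward_mem(n_samples, in_ch=576, out_ch=1024, length=11):
    """Hypothetical helper: peak CUDA bytes allocated during one
    forward+backward of the same kernel_size=1 Conv1d."""
    torch.cuda.reset_peak_memory_stats()
    conv = nn.Conv1d(in_ch, out_ch, kernel_size=1).cuda()
    v = torch.rand(n_samples, in_ch, length, device="cuda")
    conv(v).mean().backward()
    return torch.cuda.max_memory_allocated()

# Compare e.g. peak_backward_mem(5000) vs peak_backward_mem(4000)
```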
Edit: I have now read about torch.backends.cudnn.deterministic = True, which slightly changes this behavior, but the problem still occurs when I go down to 3000 samples.