Freeziey
Freeziey

Reputation: 19

OutOfMemoryError: CUDA out of memory with a gan

I have been trying to setup this gan but i cannot escape this error: https://pastebin.com/rp0kNXiH (sometimes it's not the same line)

I have tried gc.collect but maybe i misplaced it. I also tried to have as little as possible things on the gpu in my training loop but nothing seems to work.

I code on a kaggle notebook.

here's my code: https://paste.pythondiscord.com/IVMQ

import torch
import os
from torch import nn
import numpy as np
import math
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchsummary import summary
import itertools
from tqdm import tqdm
import gc


device = ""
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


X = []
Y = []

c = 0
data_size = 1024
print(torch.cuda.mem_get_info())
dirname = "../input/mushrooms/train"
for filename in tqdm(os.listdir(dirname)):
    if filename != "_classes.csv":
        im  = Image.open(os.path.join(dirname, filename))
        im1 = im.resize((512,512))
        Y.append(np.array(im1))
        im2 = im.resize((128,128))
        X.append(np.array(im2))
    
    c+=1
    if(c == data_size-1):
        break

X = np.array(X, dtype = 'float32')
Y = np.array(Y, dtype = 'float32')

X = torch.tensor(X)
Y = torch.tensor(Y)

X = torch.transpose(X,1,3)
Y = torch.transpose(Y,1,3)

batch_size = 32

#attention c'est toujours le même ordre de data
low_loader = torch.utils.data.DataLoader(
    X, batch_size=batch_size#, shuffle=True
)

high_loader = torch.utils.data.DataLoader(
    Y, batch_size=batch_size#, shuffle=True
)

print("data loaded in the ram")
print(torch.cuda.mem_get_info())

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential( #améliorations possible: batchnorm, dropout
            #input: 3x512x512
            nn.Conv2d(3,64,3,padding=1), 
            nn.ReLU(),
            nn.Conv2d(64,64,3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            
            #input: 64x256x256
            nn.Conv2d(64,128,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(128,128,3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            #input: 128x128x128
            nn.Conv2d(128,256,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(256,256,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(256,256,3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            #input: 256x64x64
            nn.Conv2d(256,512,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(512,512,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(512,512,3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            #input: 512x32x32
            nn.Conv2d(512,512,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(512,512,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(512,512,3,padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            #input: 512x16x16
            nn.Flatten(),
            nn.Linear(512*16*16, 10),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(10,5),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(5,1),
            nn.Sigmoid()   
        )

    def forward(self, x):
        return self.model(x)

discriminator = Discriminator().to(device=device)
summary(discriminator.model, (3, 512, 512))
print("disriminator set")
print(torch.cuda.mem_get_info())

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Encoder
        self.mp = nn.MaxPool2d(2,stride=2)

        #input: 3x128x128
        self.c1 = nn.Sequential(
            nn.Conv2d(3,64,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(64,64,3,padding=1),
            nn.ReLU())

        #input: 64x64x64
        self.c2 = nn.Sequential(
            nn.Conv2d(64,128,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(128,128,3,padding=1),
            nn.ReLU())

        #input: 128x32x32
        self.c3 = nn.Sequential(
            nn.Conv2d(128,256,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(256,256,3,padding=1),
            nn.ReLU())

        #input: 256x16x16
        self.c4 = nn.Sequential(
            nn.Conv2d(256,512,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(512,512,3,padding=1),
            nn.ReLU())

        #input: 512x8x8
        self.c5 = nn.Sequential(
            nn.Conv2d(512,1024,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(1024,1024,3,padding=1),
            nn.ReLU())

        # Decoder
        #input: 1024x8x8
        self.tc1 = nn.ConvTranspose2d(1024,512,2,stride=2)
        #output: 512x16x16
        #skip connection
        #input: 1024x16x16
        self.u1 = nn.Sequential(
            nn.Conv2d(1024,512,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(512,512,3,padding=1),
            nn.ReLU())
        #output: 512x16x16

        #input: 512x16x16
        self.tc2 = nn.ConvTranspose2d(512,256,2,stride=2)
        #output: 256x32x32
        #skip connection
        #input: 512x32x32
        self.u2 = nn.Sequential(
            nn.Conv2d(512,256,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(256,256,3,padding=1),
            nn.ReLU())
        #output: 256x32x32
        
        #input: 256x32x32
        self.tc3 = nn.ConvTranspose2d(256,128,2,stride=2)
        #output: 128x64x64
        #skip connection
        #input: 256x64x64
        self.u3 = nn.Sequential(
            nn.Conv2d(256,128,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(128,128,3,padding=1),
            nn.ReLU())
        #output: 128x64x64

        #input: 128x64x64
        self.tc4 = nn.ConvTranspose2d(128,64,2,stride=2)
        #output: 64x128x128
        #skip connection
        #input: 128x128x128
        self.u4 = nn.Sequential(
            nn.Conv2d(128,64,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(64,64,3,padding=1),
            nn.ReLU())
        #output: 64x128x128

        #input: 64x128x128
        self.tc5 = nn.ConvTranspose2d(64,32,2,stride=2)
        #output: 32x256x256
        #input: 32x256x256
        self.u5 = nn.Sequential(
            nn.Conv2d(32,32,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(32,32,3,padding=1),
            nn.ReLU())
        #output: 32x256x256

        #input: 32x256x256
        self.tc6 = nn.ConvTranspose2d(32,16,2,stride=2)
        #output: 16x512x512
        #input: 16x512x512
        self.u6 = nn.Sequential(
            nn.Conv2d(16,16,3,padding=1),
            nn.ReLU(),
            nn.Conv2d(16,16,3,padding=1),
            nn.ReLU())
        #output: 16x512x512
        
        #input: 16x512x512
        self.last = nn.Conv2d(16,3,1)

    
    def forward(self, x):
        
        x1 = self.c1(x)
        x1p = self.mp(x1)

        x2 = self.c2(x1p)
        x2p = self.mp(x2)

        x3 = self.c3(x2p)
        x3p = self.mp(x3)

        x4 = self.c4(x3p)
        x4p = self.mp(x4)

        x5 = self.c5(x4p)

        xt1 = self.tc1(x5)
        xc1 = self.u1(torch.cat([x4, xt1], dim=1))

        xt2 = self.tc2(xc1)
        xc2 = self.u2(torch.cat([x3, xt2], dim=1))

        xt3 = self.tc3(xc2)
        xc3 = self.u3(torch.cat([x2, xt3], dim=1))

        xt4 = self.tc4(xc3)
        xc4 = self.u4(torch.cat([x1, xt4], dim=1))

        xt5 = self.tc5(xc4)
        xc5 = self.u5(xt5)

        xt6 = self.tc6(xc5)
        xc6 = self.u6(xt6)
        
        return self.last(xc6)


generator = Generator().to(device=device)
print("generator set")
print(torch.cuda.mem_get_info())

lr = 0.0001
num_epochs = 15
loss_function = nn.BCELoss

optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)



for epoch in range(num_epochs):
    for low_samples,high_samples in itertools.zip_longest(enumerate(low_loader), enumerate(high_loader)):
        low_samples = low_samples[1]
        high_samples = high_samples[1]

        generated_labels =  torch.zeros((batch_size, 1))
        generated_samples = generator(low_samples.to(device=device))

        high_samples_labels = torch.ones((batch_size, 1))
        all_samples = torch.cat((high_samples, generated_samples.to('cpu')))
        all_samples_labels = torch.cat((high_samples_labels, generated_labels))

        # Training the discriminator
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples.to(device=device))
        loss_discriminator = loss_function(output_discriminator, all_samples_labels.to(device=device))

        loss_discriminator.backward() #compute gradient
        optimizer_discriminator.step() #backpropagation


        # Training the generator
        generator.zero_grad()
        generated_output = generator(low_samples.to(device=device))
        discriminated_generated_output = discriminator(generated_output)
        loss_generator = loss_function(discriminated_generated_output,high_samples_labels.to(device=device))

        loss_generator.backward() #compute gradient
        optimizer_generator.step() #backpropagation
    
    
    print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
    print(f"Epoch: {epoch} Loss G.: {loss_generator}")

the error :

OutOfMemoryError                          Traceback (most recent call last)
<ipython-input-4-aa3f31ef5eca> in <cell line: 10>()
     22         # Training the discriminator
     23         discriminator.zero_grad()
---> 24         output_discriminator = discriminator(all_samples.to(device=device))
     25         loss_discriminator = loss_function(output_discriminator, all_samples_labels.to(device=device))
     26 
 
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1551             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552         else:
-> 1553             return self._call_impl(*args, **kwargs)
   1554 
   1555     def _call_impl(self, *args, **kwargs):
 
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1560                 or _global_backward_pre_hooks or _global_backward_hooks
   1561                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562             return forward_call(*args, **kwargs)
   1563 
   1564         try:
 
<ipython-input-2-abdec4b3329d> in forward(self, x)
     57 
     58     def forward(self, x):
---> 59         return self.model(x)
     60 
     61 discriminator = Discriminator().to(device=device)
 
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1551             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552         else:
-> 1553             return self._call_impl(*args, **kwargs)
   1554 
   1555     def _call_impl(self, *args, **kwargs):
 
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1560                 or _global_backward_pre_hooks or _global_backward_hooks
   1561                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562             return forward_call(*args, **kwargs)
   1563 
   1564         try:
 
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py in forward(self, input)
    217     def forward(self, input):
    218         for module in self:
--> 219             input = module(input)
    220         return input
    221 
 
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1551             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552         else:
-> 1553             return self._call_impl(*args, **kwargs)
   1554 
   1555     def _call_impl(self, *args, **kwargs):
 
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1560                 or _global_backward_pre_hooks or _global_backward_hooks
   1561                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562             return forward_call(*args, **kwargs)
   1563 
   1564         try:
 
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/activation.py in forward(self, input)
    102 
    103     def forward(self, input: Tensor) -> Tensor:
--> 104         return F.relu(input, inplace=self.inplace)
    105 
    106     def extra_repr(self) -> str:
 
/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in relu(input, inplace)
   1498         result = torch.relu_(input)
   1499     else:
-> 1500         result = torch.relu(input)
   1501     return result
   1502 
 
OutOfMemoryError: CUDA out of memory. Tried to allocate 4.00 GiB. GPU 0 has a total capacity of 14.74 GiB of which 1.71 GiB is free. Process 7476 has 13.03 GiB memory in use. Of the allocated memory 12.32 GiB is allocated by PyTorch, and 595.48 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)```

Upvotes: 1

Views: 49

Answers (1)

David Robinson
David Robinson

Reputation: 356

The issue seems to be with how the data is inefficiently handled during training, specifically with the creation of large tensors and unnecessary device transfers during training. The error states that PyTorch is already using 12.3 GiB of GPU memory and attempts to allocate an additional 4 GiB with all_samples.to(device=device). At each epoch, you move generated_samples to the GPU and then move it back to the CPU for all_samples and then back to the GPU, which increases overhead and memory usage as Python doesn't clear the unused memory (i.e. generated_samples) unless it is explicitly freed.

  1. Ensure that samples are only transferred to the GPU once and there isn't duplicates in GPU memory, such as generated_samples and all_samples.
  2. Avoid concatenating large tensors to create all_samples and all_samples_labels by training the discriminator separately on real and fake samples
real_output = discriminator(high_samples.to(device))
real_loss = loss_function(real_output, torch.ones((batch_size, 1)).to(device))

# Train discriminator on fake samples
fake_output = discriminator(generator(low_samples.to(device)).detach())
fake_loss = loss_function(fake_output, torch.zeros((batch_size, 1)).to(device))

loss_discriminator = real_loss + fake_loss
  1. If each sample still overloads the GPU, reduce the batch size to 8 or 16 to reduce the per-batch memory footprint.
  2. Clear unused memory between batches with the following code at the beginning of the loop:
torch.cuda.empty_cache()
gc.collect()

Upvotes: 0

Related Questions