Reputation: 145
I am trying to implement a model in PyTorch. The training procedure is quite complex and takes a while, but what I have noticed is that the model is very fast on the first few batches and then suddenly gets about 500 times slower. My guess was some kind of memory leak, as if Python were not actually freeing the memory of the released huge tensors.
At first I thought the problem was linked to storing gradients, but the same issue appears even with torch.no_grad().
Here is an example that replicates the problem. (Note that I am not actually trying to train this specific network; the problem looks the same regardless.) To keep things simple I am not using gradients and I am iterating over the same batch.
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision.transforms as T
dataset = MNIST(root='./MNIST', train=True, download=True,
                transform=T.Compose([T.ToTensor(), T.Lambda(lambda x: torch.flatten(x))]))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=500)
X, _ = next(iter(data_loader))
X = X.to('cuda')
in_features = 28*28
out_features = 10
width = 15000

# defining a huge network
NN = nn.Sequential(
    nn.Linear(in_features=in_features, out_features=width, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=width, out_features=width, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=width, out_features=width, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=width, out_features=width, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=width, out_features=width, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=width, out_features=width, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=width, out_features=width, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=width, out_features=width, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=width, out_features=width, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=width, out_features=width, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=width, out_features=out_features, bias=False),
).to('cuda')
import time

iterations = 100
X = X.to('cuda')
with torch.no_grad():
    for idx in range(iterations):
        print(f'Iteration {idx+1}')
        start = time.time()
        Y = NN(X)
        print(f'Time: {time.time() - start}')
The output shows that everything is very fast up to almost the 50th iteration, then it suddenly slows down.
Iteration 44
Time: 0.00035953521728515625
Iteration 45
Time: 0.00035309791564941406
Iteration 46
Time: 0.00035309791564941406
Iteration 47
Time: 0.048192501068115234
Iteration 48
Time: 0.1714644432067871
Iteration 49
Time: 0.16771984100341797
Iteration 50
Time: 0.1681973934173584
Iteration 51
Time: 0.16853046417236328
Iteration 52
Time: 0.16821908950805664
Why is there such a slowdown? Is it possible to avoid it somehow?
Upvotes: 2
Views: 2630
Reputation: 2126
Check out this page and scroll down to "Asynchronous execution".
Basically, you are measuring the time it takes to enqueue your operations on the GPU, not the time it actually takes to execute them. This is because GPU calls are asynchronous, as described in the link. I copied the relevant part below:
By default, GPU operations are asynchronous. When you call a function that uses the GPU, the operations are enqueued to the particular device, but not necessarily executed until later. This allows us to execute more computations in parallel, including operations on CPU or other GPUs.
In general, the effect of asynchronous computation is invisible to the caller, because (1) each device executes operations in the order they are queued, and (2) PyTorch automatically performs necessary synchronization when copying data between CPU and GPU or between two GPUs. Hence, computation will proceed as if every operation was executed synchronously.
You can force synchronous computation by setting environment variable CUDA_LAUNCH_BLOCKING=1. This can be handy when an error occurs on the GPU. (With asynchronous execution, such an error isn’t reported until after the operation is actually executed, so the stack trace does not show where it was requested.)
A consequence of the asynchronous computation is that time measurements without synchronizations are not accurate. To get precise measurements, one should either call torch.cuda.synchronize() before measuring, or use torch.cuda.Event to record times as follows:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
# Run some things here
end_event.record()
torch.cuda.synchronize() # Wait for the events to be recorded!
elapsed_time_ms = start_event.elapsed_time(end_event)
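For reference, here is one way to apply the first approach (torch.cuda.synchronize()) to the timing loop from your question. This is just a minimal sketch that reuses the NN, X and iterations you already defined:

import time

with torch.no_grad():
    for idx in range(iterations):
        torch.cuda.synchronize()   # make sure all previously queued GPU work has finished
        start = time.time()
        Y = NN(X)
        torch.cuda.synchronize()   # wait for this forward pass to actually complete
        print(f'Iteration {idx+1}: {time.time() - start:.4f} s')

With the explicit synchronization, every iteration should report a similar (and realistic) time from the very first batch. The sudden jump around iteration 47 in your output is most likely the point where the queue of pending kernel launches fills up, so the enqueue call itself starts to block until the GPU catches up.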
Upvotes: 1