david Liang

Reputation: 47

Why is multiplication on GPU slower than on CPU?

Here is my code (it simulates a feed-forward neural network):

import torch
import time

print(torch.cuda.is_available())    # True
device = torch.device('cuda:0' )

a = torch.tensor([1,2,3,4,5,6]).float().reshape(-1,1)
w1 = torch.rand(120,6)
w2 = torch.rand(1,120)
b1 = torch.rand(120,1)
b2 = torch.rand(1,1).reshape(1,1)

start = time.time()
for _ in range(100000):
    ans = torch.mm(w2, torch.mm(w1,a)+b1)+b2
end = time.time()
print(end-start)                    # 1.2725720405578613 seconds

a = a.to(device)
w1 = w1.to(device)
w2 = w2.to(device)
b1 = b1.to(device)
b2 = b2.to(device)

start = time.time()
for _ in range(100000):
    ans = torch.mm(w2, torch.mm(w1,a)+b1)+b2
end = time.time()
print(end-start)                    # 5.6569812297821045 seconds

I wonder whether I did something wrong. How can I change my code to show that the GPU is faster than the CPU at matrix multiplication?

Upvotes: 3

Views: 1834

Answers (2)

Nivesh Gadipudi

Reputation: 506

CPU-to-GPU transfer comes with an overhead. You can also observe that the first layer of a model takes a large amount of time compared to the subsequent ones.

This is because the tensors are first transferred from host memory to GPU memory. Only then do the CUDA cores perform operations on the tensors in GPU memory.
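To see this overhead for yourself, a rough sketch like the following times the host-to-device copy and the multiplication separately. The helper name copy_and_multiply and the 512x512 size are my own choices for illustration, and the script falls back to the CPU when no GPU is present (in which case the copy is a no-op). On a GPU, the first call typically also pays for one-time CUDA initialization, so expect it to be the slowest.

```python
import time
import torch

def copy_and_multiply(device, size=512):
    """Return (copy_seconds, matmul_seconds) for one size x size matrix."""
    x = torch.rand(size, size)          # created in host (CPU) memory

    start = time.perf_counter()
    x_dev = x.to(device)                # host -> device copy (no-op on CPU)
    if device.type == "cuda":
        torch.cuda.synchronize()        # wait until the copy has finished
    copied = time.perf_counter()

    y = torch.mm(x_dev, x_dev)
    if device.type == "cuda":
        torch.cuda.synchronize()        # wait until the kernel has finished
    done = time.perf_counter()
    return copied - start, done - copied

# Assumption: use the first GPU if there is one, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
for call in range(3):
    copy_s, mm_s = copy_and_multiply(device)
    print(f"call {call}: copy {copy_s:.6f} s, matmul {mm_s:.6f} s")
```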

Upvotes: 0

Károly Szabó

Reputation: 1273

There can be several reasons:

  1. Your model is very simple, so each operation gives the GPU very little work to do.
  2. GPU computation incurs the cost of transferring data to and from the GPU's memory.
  3. Your calculation runs on a small data batch; with a larger batch you should probably see better performance on the GPU than on the CPU.
  4. We should not forget caching: you calculate the same operations over and over again, so it may be better to generate a new random tensor a for every run.
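As a rough sketch of points 3 and 4, the network from the question can be run on a whole batch of inputs at once (one column of a per sample), with a fresh random a generated on every call. The function name forward_time, the batch sizes and the iteration count are assumptions I picked for illustration; the numbers you get will depend entirely on your hardware.

```python
import time
import torch

def forward_time(device, batch_size, n_iters=100):
    """Time n_iters forward passes of the 6 -> 120 -> 1 network on one device."""
    w1 = torch.rand(120, 6, device=device)
    b1 = torch.rand(120, 1, device=device)
    w2 = torch.rand(1, 120, device=device)
    b2 = torch.rand(1, 1, device=device)
    a = torch.rand(6, batch_size, device=device)   # fresh random input per call

    # Warm-up pass so one-time initialization is not included in the timing.
    torch.mm(w2, torch.mm(w1, a) + b1) + b2
    if device.type == "cuda":
        torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(n_iters):
        ans = torch.mm(w2, torch.mm(w1, a) + b1) + b2
    if device.type == "cuda":
        torch.cuda.synchronize()   # CUDA kernels run asynchronously; wait for them
    return time.perf_counter() - start

devices = [torch.device("cpu")]
if torch.cuda.is_available():
    devices.append(torch.device("cuda:0"))

for batch_size in (1, 4096):
    for device in devices:
        seconds = forward_time(device, batch_size)
        print(f"{device}, batch {batch_size}: {seconds:.4f} s")
```

With a batch size of 1 this is essentially the question's workload; as the batch grows, the fixed per-call overhead on the GPU is spread over more work.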

Here is a thread on the PyTorch forum: https://discuss.pytorch.org/t/cpu-faster-than-gpu/25343

You should also use a better way of measuring time, as explained in this thread: https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964
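One option (not necessarily what that thread recommends) is torch.utils.benchmark.Timer, which ships with PyTorch and takes care of warm-up and CUDA synchronization for you. Below is a minimal sketch using the shapes from the question; the helper name mean_forward_seconds and the run count are my own assumptions.

```python
import torch
import torch.utils.benchmark as benchmark

def mean_forward_seconds(device, n_runs=100):
    """Mean time in seconds of one forward pass, measured by benchmark.Timer."""
    tensors = {
        "a": torch.rand(6, 1, device=device),
        "w1": torch.rand(120, 6, device=device),
        "b1": torch.rand(120, 1, device=device),
        "w2": torch.rand(1, 120, device=device),
        "b2": torch.rand(1, 1, device=device),
    }
    timer = benchmark.Timer(
        stmt="torch.mm(w2, torch.mm(w1, a) + b1) + b2",
        globals={"torch": torch, **tensors},
    )
    # timeit returns a Measurement; .mean is the average time per run.
    return timer.timeit(n_runs).mean

print("cpu:", mean_forward_seconds(torch.device("cpu")))
if torch.cuda.is_available():
    print("cuda:", mean_forward_seconds(torch.device("cuda:0")))
```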

Upvotes: 8
