Reputation: 471
I have two tensors of shape (16384, 3, 224, 224) each. I need to multiply these two together. Obviously these two tensors are too big to fit in GPU RAM. So I want to know, how should I go about this: divide them into smaller batches using slicing, or should I use two separate DataLoaders? (I am confused about how to use two different DataLoaders together.)
What would be the best way to do this?
Upvotes: 0
Views: 563
Reputation: 22224
I'm still not sure I totally understand the problem, but under the assumption that you have two big tensors t1 and t2 of shape [16384, 3, 224, 224] already loaded in RAM and want to perform element-wise multiplication, then the easiest approach is
result = t1 * t2
Alternatively, you could break these into smaller tensors and multiply them that way. There are lots of ways to do this.
One very PyTorch-like way is to use a TensorDataset and operate on corresponding mini-batches of both tensors. If all you want to do is element-wise multiplication, then the overhead of transferring tensors to and from the GPU is likely more expensive than the actual time saved during the computation. If you want to try it, you can use something like this:
import torch
from torch.utils import data

batch_size = 100
device = 'cuda:0'

# Pair up the two tensors so the DataLoader yields corresponding mini-batches.
dataset = data.TensorDataset(t1, t2)
dataloader = data.DataLoader(dataset, num_workers=1, batch_size=batch_size)

result = []
for d1, d2 in dataloader:
    # Move the current mini-batches to the GPU, multiply, and copy the result back to CPU.
    d1, d2 = d1.to(device=device), d2.to(device=device)
    d12 = d1 * d2
    result.append(d12.cpu())
result = torch.cat(result, dim=0)
Or you could just do some slicing, which will probably be faster and more memory-efficient since it avoids data copying on the CPU side:
import torch

batch_size = 100
device = 'cuda:0'

index = 0
result = []
while index < t1.shape[0]:
    # Slice out the next chunk of each tensor and move it to the GPU.
    d1 = t1[index:index + batch_size].to(device=device)
    d2 = t2[index:index + batch_size].to(device=device)
    d12 = d1 * d2
    result.append(d12.cpu())
    index += batch_size
result = torch.cat(result, dim=0)
Note that for both of these examples most of the time is spent copying data back to the CPU and concatenating the final results. Ideally, you would just do whatever you need to do with the d12 batch within the loop and avoid ever sending the final multiplied result back to the CPU.
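For example, here is a minimal sketch of the slicing version where each product is reduced on the GPU instead of being copied back; the sum is just a hypothetical placeholder for whatever you actually need to compute from each d12 batch:

import torch

batch_size = 100
device = 'cuda:0'

index = 0
# Accumulate a per-batch statistic on the GPU instead of materializing
# the full multiplied tensor on the CPU.
total = torch.zeros((), device=device)
while index < t1.shape[0]:
    d1 = t1[index:index + batch_size].to(device=device)
    d2 = t2[index:index + batch_size].to(device=device)
    d12 = d1 * d2
    # Do whatever further work you need with d12 while it is still on the GPU,
    # e.g. reduce it to a scalar and accumulate.
    total += d12.sum()
    index += batch_size
print(total.item())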
Upvotes: 1