Umair Javaid

Reputation: 471

How to use two separate dataloaders together?

I have two tensors of shape (16384, 3, 224, 224) each, and I need to multiply them together. Obviously these two tensors are too big to fit in GPU RAM, so I want to know how to go about this: should I divide them into smaller batches using slicing, or should I use two separate dataloaders? (I am confused about how to use two different dataloaders together.) What would be the best way to do this?

Upvotes: 0

Views: 563

Answers (1)

jodag

Reputation: 22224

I'm still not sure I totally understand the problem, but under the assumption that you have two big tensors t1 and t2 of shape [16384, 3, 224, 224] already loaded in RAM and want to perform element-wise multiplication, the easiest approach is

result = t1 * t2

Alternatively you could break these into smaller tensors and multiply them that way. There are lots of ways to do this.
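For instance, one simple way (a minimal sketch, not from the original answer, assuming t1 and t2 are the CPU tensors described above and that the number of chunks is an arbitrary choice) is to split both tensors along the batch dimension with torch.chunk and multiply corresponding pieces on the GPU:

import torch

# split both tensors into matching pieces along dim 0, multiply each pair on
# the GPU, and collect the products back on the CPU
# num_chunks is an arbitrary, hypothetical choice
num_chunks = 64
device = 'cuda:0'

pieces = []
for c1, c2 in zip(torch.chunk(t1, num_chunks, dim=0),
                  torch.chunk(t2, num_chunks, dim=0)):
    pieces.append((c1.to(device) * c2.to(device)).cpu())

result = torch.cat(pieces, dim=0)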

One very PyTorch-like way is to use a TensorDataset and operate on corresponding mini-batches of both tensors. If all you want to do is element-wise multiplication, then the overhead of transferring tensors to and from the GPU is likely more expensive than the actual time saved during the computation. If you want to try it, you can use something like this:

import torch
from torch.utils import data

batch_size = 100
device = 'cuda:0'

# wrap both tensors so the loader yields corresponding mini-batches from each
dataset = data.TensorDataset(t1, t2)
dataloader = data.DataLoader(dataset, num_workers=1, batch_size=batch_size)

result = []
for d1, d2 in dataloader:
    # move each pair of mini-batches to the GPU, multiply, and copy the product back
    d1, d2 = d1.to(device=device), d2.to(device=device)
    d12 = d1 * d2

    result.append(d12.cpu())

result = torch.cat(result, dim=0)

Or you could just do some slicing, which will probably be faster and more memory-efficient since it avoids data copying on the CPU side.

import torch

batch_size = 100
device = 'cuda:0'

index = 0
result = []
while index < t1.shape[0]:
    # slice matching chunks out of each tensor and multiply them on the GPU
    d1 = t1[index:index + batch_size].to(device=device)
    d2 = t2[index:index + batch_size].to(device=device)
    d12 = d1 * d2

    result.append(d12.cpu())
    index += batch_size

result = torch.cat(result, dim=0)

Note that for both of these examples most of the time is spent copying data back to the CPU and concatenating the final results. Ideally, you would just do whatever you need to do with the d12 batch within the loop and avoid ever sending the final multiplied result back to the CPU.
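For example, if the downstream step were a reduction (a hypothetical use case; the question doesn't say what the product is needed for), you could accumulate it directly on the GPU and only ever copy a tiny tensor back:

import torch

batch_size = 100
device = 'cuda:0'

# hypothetical sketch: accumulate a per-channel sum of the products on the GPU,
# so the full multiplied result is never transferred back to the CPU
total = torch.zeros(3, device=device)
index = 0
while index < t1.shape[0]:
    d1 = t1[index:index + batch_size].to(device=device)
    d2 = t2[index:index + batch_size].to(device=device)
    total += (d1 * d2).sum(dim=(0, 2, 3))
    index += batch_size

total = total.cpu()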

Upvotes: 1
