Auss

Reputation: 491

PyTorch DataLoader returns the batch as a list with the batch as the only entry. What is the best way to get a tensor from my DataLoader?

I currently have the following situation where I want to use DataLoader to batch a numpy array:

import numpy as np
import torch
import torch.utils.data as data_utils

# Create toy data
x = np.linspace(start=1, stop=10, num=10)
x = np.array([np.random.normal(size=len(x)) for i in range(100)])
print(x.shape)
# >> (100, 10)

# Create DataLoader
input_as_tensor = torch.from_numpy(x).float()
dataset = data_utils.TensorDataset(input_as_tensor)
dataloader = data_utils.DataLoader(dataset,
                                   batch_size=100,
                                   )
batch = next(iter(dataloader))

print(type(batch))
# >> <class 'list'>

print(len(batch))
# >> 1

print(type(batch[0]))
# >> <class 'torch.Tensor'>

I expect the batch to already be a torch.Tensor. As of now I index the batch like so, batch[0], to get a Tensor, but I feel this is not really pretty and makes the code harder to read.

I found that the DataLoader takes a batch processing function called collate_fn. However, setting data_utils.DataLoader(..., collate_fn=lambda batch: batch[0]) only gives me a tuple like (tensor([ 0.8454, ..., -0.5863]),) with a single Tensor as its only entry.
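For reference, collate_fn receives the list of individual samples (here, one-element tuples from the TensorDataset), so something like the following sketch would stack them into a single batch tensor, but it feels more cumbersome than it should be:

# Sketch: collate_fn gets a list of samples; stack their tensors into one batch tensor
dataloader = data_utils.DataLoader(dataset,
                                   batch_size=100,
                                   collate_fn=lambda samples: torch.stack([s[0] for s in samples]))
batch = next(iter(dataloader))
# batch is now a torch.Tensor of shape (100, 10)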

You would help me a lot by showing me how to elegantly transform the batch into a tensor (even if that includes telling me that indexing the single entry in batch is okay).

Upvotes: 3

Views: 8989

Answers (1)

Szymon Maszke

Reputation: 24691

Sorry for the inconvenience with my answer.

Actually, you don't have to create a Dataset from your tensor; you can pass a torch.Tensor directly, as it implements __getitem__ and __len__, so this is sufficient:

import numpy as np
import torch
import torch.utils.data as data_utils

# Create toy data
x = np.linspace(start=1, stop=10, num=10)
x = np.array([np.random.normal(size=len(x)) for i in range(100)])

# Create DataLoader directly from the tensor (no TensorDataset needed)
dataset = torch.from_numpy(x).float()
dataloader = data_utils.DataLoader(dataset, batch_size=100)
batch = next(iter(dataloader))
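
The batch coming out of the DataLoader is then a plain torch.Tensor, which a quick check (using the same toy data as above) confirms:

print(type(batch))
# >> <class 'torch.Tensor'>
print(batch.shape)
# >> torch.Size([100, 10])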

Upvotes: 1
