Statham

Reputation: 4118

Why does PyTorch's DataLoader behave differently on a NumPy array and a list?

The only difference between the two cases is that the data passed to the dataset (and hence to the DataLoader) is a numpy.ndarray in one case and a list in the other, yet the DataLoader gives completely different results.

You can use the following code to reproduce it:

from torch.utils.data import DataLoader, Dataset
import numpy as np

class my_dataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)

train_data = [[1, 2, 3], [5, 6, 7], [11, 12, 13], [15, 16, 17]]
train_label = [-1, -2, -11, -12]

# Look here: the two loaders differ only in np.array(train_data) vs. train_data

test = DataLoader(dataset=my_dataset(np.array(train_data), train_label), batch_size=2)
for i in test:
    print("numpy data:")
    print(i)
    break

test = DataLoader(dataset=my_dataset(train_data, train_label), batch_size=2)
for i in test:
    print("list data:")
    print(i)
    break

The result is:

numpy data:
[tensor([[1, 2, 3],
        [5, 6, 7]]), tensor([-1, -2])]
list data:
[[tensor([1, 5]), tensor([2, 6]), tensor([3, 7])], tensor([-1, -2])]  

Upvotes: 8

Views: 8556

Answers (1)

Sasank Chilamkurthy

Reputation: 1080

This is because of how batching is handled in torch.utils.data.DataLoader. The collate_fn argument decides how the samples drawn from the dataset are merged into a single batch. The default for this argument is torch.utils.data.default_collate (undocumented when this answer was written).
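To see the difference without a DataLoader, you can call the default collate function directly on two hand-written samples shaped exactly like what my_dataset.__getitem__ returns in each case. This is a sketch assuming a recent PyTorch release in which default_collate is importable from torch.utils.data; older releases exposed it as torch.utils.data.dataloader.default_collate.

```python
import numpy as np
from torch.utils.data import default_collate  # assumes a recent PyTorch release

# Samples as my_dataset returns them when self.data is a NumPy array:
# (1-D ndarray, int)
numpy_samples = [(np.array([1, 2, 3]), -1), (np.array([5, 6, 7]), -2)]

# Samples as my_dataset returns them when self.data is a list of lists:
# (list of ints, int)
list_samples = [([1, 2, 3], -1), ([5, 6, 7], -2)]

numpy_batch = default_collate(numpy_samples)
list_batch = default_collate(list_samples)

print(numpy_batch)  # the two ndarrays are stacked into one 2x3 tensor
print(list_batch)   # the inner list is kept; each position is batched separately
```

These two calls reproduce the "numpy data" and "list data" outputs from the question, which shows the difference comes from collation, not from the DataLoader's sampling.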

This function handles batching by treating numbers, tensors and ndarrays as primitive data to be batched, and lists, tuples and dicts containing these primitives as structure to be (recursively) preserved. This allows semantic batching like the following:

  1. (input_tensor, label_tensor) -> (batched_input_tensor, batched_label_tensor)
  2. ([input_tensor_1, input_tensor_2], label_tensor) -> ([batched_input_tensor_1, batched_input_tensor_2], batched_label_tensor)
  3. {'input': input_tensor, 'target': target_tensor} -> {'input': batched_input_tensor, 'target': batched_target_tensor}

(The left side of -> is the output of dataset[i]; the right side is the batched sample produced by torch.utils.data.DataLoader.)
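The recursion behind these three rules can be illustrated with a much-simplified, pure-Python sketch. The function simple_collate is hypothetical and is not PyTorch's implementation: it batches numbers into plain lists instead of tensors, does not handle tensors or ndarrays, and skips the size checks the real default_collate performs.

```python
from collections.abc import Mapping, Sequence


def simple_collate(batch):
    """Hypothetical, simplified version of the recursion in default_collate.

    Numbers are treated as primitives and gathered into a list (the real
    function builds a tensor); dicts, lists and tuples are treated as
    structure and recursed into.
    """
    elem = batch[0]
    if isinstance(elem, (int, float)):
        # Primitive: batch the values from all samples together
        return list(batch)
    if isinstance(elem, Mapping):
        # Dict: keep the keys, collate each key's values across samples
        return {key: simple_collate([sample[key] for sample in batch]) for key in elem}
    if isinstance(elem, Sequence) and not isinstance(elem, str):
        # List/tuple: keep the structure, transpose the samples,
        # then collate each position separately
        return [simple_collate(list(position)) for position in zip(*batch)]
    raise TypeError(f"unsupported element type: {type(elem)!r}")


# Same shape as the question's list samples: ([ints], int)
print(simple_collate([([1, 2, 3], -1), ([5, 6, 7], -2)]))
# Rule 3 above: a dict per sample becomes a dict of batches
print(simple_collate([{"input": 1, "target": 0}, {"input": 2, "target": 1}]))
```

The first call returns [[[1, 5], [2, 6], [3, 7]], [-1, -2]], the same layout as the question's "list data" output with lists standing in for tensors.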

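Your example code is similar to example 2 above: the list structure is preserved while the ints are batched. With the NumPy array, each sample's data is a 1-D ndarray, which is a primitive, so [1, 2, 3] and [5, 6, 7] are stacked into one 2x3 tensor. With the list, each sample's data is a list of ints, which is treated as structure, so the samples are transposed position by position and you get one tensor per position: [1, 5], [2, 6] and [3, 7].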
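If you want the stacked 2x3 layout even when the dataset is built from plain lists, one option is to convert each sample to an array inside __getitem__, so default_collate sees a primitive instead of a list. This is a minimal sketch; ArrayDataset is a hypothetical name for a variant of the question's my_dataset, and returning torch.tensor(...) instead of np.asarray(...) should work the same way.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


class ArrayDataset(Dataset):
    """Hypothetical variant of my_dataset that returns each sample's data
    as a NumPy array, so default_collate stacks it instead of transposing it."""

    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __getitem__(self, index):
        # np.asarray turns [1, 2, 3] into a 1-D ndarray, which default_collate
        # treats as a primitive to stack rather than a list to preserve
        return np.asarray(self.data[index]), self.label[index]

    def __len__(self):
        return len(self.data)


train_data = [[1, 2, 3], [5, 6, 7], [11, 12, 13], [15, 16, 17]]
train_label = [-1, -2, -11, -12]

loader = DataLoader(ArrayDataset(train_data, train_label), batch_size=2)
first_batch = next(iter(loader))
print(first_batch)  # same layout as the "numpy data" output in the question
```

The alternative is to leave the dataset alone and pass a custom collate_fn to the DataLoader, which is more flexible but means re-implementing the batching yourself.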

Upvotes: 13
