Reputation: 4118
The only difference is that one of the parameters passed to DataLoader is of type "numpy.array" and the other is of type "list", yet the DataLoader gives totally different results.
You can use the following code to reproduce it:
```python
from torch.utils.data import DataLoader, Dataset
import numpy as np

class my_dataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)

train_data = [[1, 2, 3], [5, 6, 7], [11, 12, 13], [15, 16, 17]]
train_label = [-1, -2, -11, -12]

########################### Look here:
# Same data, passed as a NumPy array
test = DataLoader(dataset=my_dataset(np.array(train_data), train_label), batch_size=2)
for i in test:
    print("numpy data:")
    print(i)
    break

# Same data, passed as a nested Python list
test = DataLoader(dataset=my_dataset(train_data, train_label), batch_size=2)
for i in test:
    print("list data:")
    print(i)
    break
```
The result is:

```
numpy data:
[tensor([[1, 2, 3],
        [5, 6, 7]]), tensor([-1, -2])]
list data:
[[tensor([1, 5]), tensor([2, 6]), tensor([3, 7])], tensor([-1, -2])]
```
Upvotes: 8
Views: 8556
Reputation: 1080
This is because of how batching is handled in `torch.utils.data.DataLoader`. Its `collate_fn` argument decides how the individual samples returned by the dataset are merged into a single batch. The default for this argument is `torch.utils.data.default_collate`, which older PyTorch releases left undocumented.
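To see the difference without a DataLoader, you can call the default collate function directly on two hand-built samples that mirror what `my_dataset.__getitem__` returns in the question. This is a minimal sketch assuming a PyTorch version that exposes `default_collate` as `torch.utils.data.default_collate` (recent releases do; older ones kept it in a private module).

```python
import numpy as np
import torch
from torch.utils.data import default_collate  # assumes a recent PyTorch release

# Two samples shaped like my_dataset(np.array(train_data), ...)[i]
numpy_batch = [(np.array([1, 2, 3]), -1), (np.array([5, 6, 7]), -2)]
# Two samples shaped like my_dataset(train_data, ...)[i]
list_batch = [([1, 2, 3], -1), ([5, 6, 7], -2)]

# ndarray rows are treated as primitives and stacked into one tensor
numpy_result = default_collate(numpy_batch)
# list rows are treated as structure: each position is batched separately
list_result = default_collate(list_batch)

print(numpy_result)
print(list_result)
```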
This function handles batching by treating numbers, tensors and ndarrays as primitive data to be batched, and lists, tuples and dicts containing these primitives as structure to be (recursively) preserved. This allows you to have semantic batching like this:
1. (input_tensor, label_tensor) -> (batched_input_tensor, batched_label_tensor)
2. ([input_tensor_1, input_tensor_2], label_tensor) -> ([batched_input_tensor_1, batched_input_tensor_2], batched_label_tensor)
3. {'input': input_tensor, 'target': target_tensor} -> {'input': batched_input_tensor, 'target': batched_target_tensor}
(The left side of -> is the output of dataset[i]; the right side is the batched sample produced by torch.utils.data.DataLoader.)
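The three patterns above can be checked directly with `default_collate`. A minimal sketch; the tensors `x1`, `x2` and the integer labels are just illustrative values, and it again assumes a PyTorch release that exposes `torch.utils.data.default_collate`:

```python
import torch
from torch.utils.data import default_collate  # assumes a recent PyTorch release

# Illustrative per-sample tensors (hypothetical values)
x1, x2 = torch.zeros(3), torch.ones(3)

# Example 1: (input, label) -> (batched_input, batched_label)
batch1 = default_collate([(x1, 0), (x2, 1)])

# Example 2: ([input_1, input_2], label) -> ([batched_input_1, batched_input_2], batched_label)
batch2 = default_collate([([x1, x2], 0), ([x2, x1], 1)])

# Example 3: the dict structure is kept, values are batched per key
batch3 = default_collate([{'input': x1, 'target': 0},
                          {'input': x2, 'target': 1}])

print(batch1)
print(batch2)
print(batch3)
```

Note that a tuple sample comes back as a list, which is why the outputs in the question are wrapped in `[...]` rather than `(...)`.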
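Your example code is similar to example 2 above: the list structure is preserved while the ints are batched. With `np.array(train_data)`, each `self.data[index]` is a 1-D ndarray, which `default_collate` treats as a primitive and stacks into a single tensor of shape (2, 3). With the plain nested list, each `self.data[index]` is a Python list of three ints, so `default_collate` keeps the list and batches it position by position, giving three tensors of size 2: `tensor([1, 5])`, `tensor([2, 6])` and `tensor([3, 7])`.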
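If the goal is to get the NumPy-style batch from list data, one option (a suggestion, not the only fix) is to return a tensor from `__getitem__`, so that `default_collate` sees a primitive and stacks it. A minimal sketch reusing the question's `my_dataset` with one changed line; passing your own `collate_fn` to `DataLoader` would be an alternative.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class my_dataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __getitem__(self, index):
        # Convert the row to a tensor so default_collate stacks it instead of
        # recursing into the list and batching each position separately
        return torch.as_tensor(self.data[index]), self.label[index]

    def __len__(self):
        return len(self.data)

train_data = [[1, 2, 3], [5, 6, 7], [11, 12, 13], [15, 16, 17]]
train_label = [-1, -2, -11, -12]

loader = DataLoader(dataset=my_dataset(train_data, train_label), batch_size=2)
first_batch = next(iter(loader))
print(first_batch)
```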
Upvotes: 13