JanEgern


Randomness in torch.utils.data.random_split vs. numpy based splitting

I'm implementing a convolutional neural network in PyTorch using the PyTorch Lightning infrastructure.

My collaborators asked me to evaluate performance across the whole dataset using k-fold cross-validation (yes, I am aware of data leakage and overfitting).

My problem is that numpy-based random splitting yields significantly worse performance than torch.utils.data.random_split. For comparison, my test performance with a plain train/val/test approach is around AUC 0.95, using an 8:1:1 split.

The sample size used to calculate the AUC is around 10,000, so the AUC difference between the two splitting approaches is significant.

My question is: why does the PyTorch random_split function return better results? In my numpy implementation, the probability of belonging to a given fold is independent between observations. Is that also the case for random_split, or does it, e.g., use random intervals of samples?
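One way to check this empirically is to look at the indices that random_split hands to each subset. The sketch below uses a toy dataset (hypothetical, not my real data) and integer fold lengths. To my understanding, random_split draws one random permutation of all indices and cuts it into consecutive chunks, so each fold should contain scattered indices rather than a contiguous interval of samples.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy stand-in for the real data: 20 samples whose single feature is the row index.
data = torch.arange(20, dtype=torch.float32).unsqueeze(1)
toy_dataset = TensorDataset(data)

# A seeded generator makes the split reproducible between runs.
generator = torch.Generator().manual_seed(0)
folds = random_split(toy_dataset, [4, 4, 4, 4, 4], generator=generator)

# Each returned Subset stores the original-dataset indices it covers.
fold_indices = [sorted(f.indices) for f in folds]
for k, idx in enumerate(fold_indices):
    print(f"fold {k}: {idx}")

# Together the folds should cover every index exactly once.
all_indices = sorted(i for idx in fold_indices for i in idx)
is_partition = all_indices == list(range(20))
print("folds partition the dataset:", is_partition)
```

If the printed indices are scattered across the whole range, the two approaches sample folds in essentially the same way, and the performance gap would have to come from somewhere else.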

K-fold Experiments

When I subset my dataset using numpy, I get an AUC of around 0.88. Here is my code:

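import numpy as np
import torch
from torch.utils.data import TensorDataset

fold    = 1
indices = np.arange(X.shape[0])
# Assign each sample to one of 5 folds by taking a random permutation modulo 5
sampler = np.random.permutation(indices) % 5 # 5 fold CV
X_train, X_test = X[(sampler!=fold)], X[(sampler==fold)]
y_train, y_test = y[(sampler!=fold)], y[(sampler==fold)]

# Split the non-test samples 90/10 into train and validation sets
dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1])
test_dataset  = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test))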

When I subset my dataset using random_split in PyTorch, I get an AUC of around 0.95. Here is my code:

fold = 1
dataset       = TensorDataset(torch.Tensor(X), torch.Tensor(y))
five_folds    = torch.utils.data.random_split(dataset, [1/5]*5) # 5 fold CV
test_dataset  = five_folds[fold]
train_dataset = torch.utils.data.ConcatDataset(
    [five_folds[i] for i in range(5) if i != fold]
)
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [0.9, 0.1])
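To rule out the splitting mechanism itself, both pipelines can be fed the same fold assignment. The following is a minimal sketch with placeholder data (the arrays, seeds, and sizes are assumptions for illustration): the fold labels are drawn once with numpy and then applied to a single TensorDataset through torch.utils.data.Subset, with a check that the test indices never appear in the training indices.

```python
import numpy as np
import torch
from torch.utils.data import Subset, TensorDataset, random_split

# Placeholder data standing in for the real X and y.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3)).astype(np.float32)
y = rng.integers(0, 2, size=100).astype(np.float32)

n_folds, fold = 5, 1
# One fold label per sample, drawn once and reused for every experiment.
fold_id = rng.permutation(len(X)) % n_folds

test_idx = np.flatnonzero(fold_id == fold)
train_val_idx = np.flatnonzero(fold_id != fold)

dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y))
test_dataset = Subset(dataset, test_idx.tolist())
train_val_dataset = Subset(dataset, train_val_idx.tolist())

# 90/10 train/validation split of the remaining samples, seeded for reproducibility.
generator = torch.Generator().manual_seed(0)
n_val = len(train_val_dataset) // 10
train_dataset, val_dataset = random_split(
    train_val_dataset, [len(train_val_dataset) - n_val, n_val], generator=generator
)

# Sanity check: no test sample may leak into the train/validation pool.
no_overlap = set(test_idx.tolist()).isdisjoint(train_val_idx.tolist())
print(len(train_dataset), len(val_dataset), len(test_dataset), no_overlap)
```

With identical indices going into both code paths, any remaining AUC difference cannot be attributed to how the folds were drawn.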
