elenaby

Reputation: 167

Stratified cross validation with Pytorch

My goal is to do binary classification with a neural network. The problem is that the dataset is imbalanced: 90% of the samples are class 1 and 10% are class 0. To deal with this I want to use stratified cross-validation.

The problem is that I am working with PyTorch. I can't find any example, and the documentation doesn't provide one. I'm a student and quite new to neural networks.

Can anybody help? Thank you!

Upvotes: 2

Views: 7263

Answers (2)

crypdick

Reputation: 19834

The easiest way I've found is to do your stratified splits before passing your data to a PyTorch Dataset and DataLoader. That lets you avoid having to port all your code to skorch, which can break compatibility with some cluster computing frameworks.
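A minimal sketch of that approach, assuming scikit-learn is available for the splitting step. The toy data, fold count, and batch size are made up for illustration and are not from the question; the training step is left as a comment because it depends on your model.

```python
import numpy as np
import torch
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical toy data mirroring the question: 90% class 1, 10% class 0.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 8)).astype(np.float32)
y = np.array([1] * 180 + [0] * 20, dtype=np.int64)

# Wrap the full data once; each fold only selects indices into it.
dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y))

# StratifiedKFold keeps the class ratio (roughly) the same in every fold.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

fold_class0_counts = []
for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
    train_loader = DataLoader(
        Subset(dataset, train_idx.tolist()), batch_size=32, shuffle=True
    )
    val_loader = DataLoader(Subset(dataset, val_idx.tolist()), batch_size=32)

    # Create a fresh model and optimizer here, train on train_loader,
    # then evaluate on val_loader. Omitted because it is model-specific.

    # Record how many minority-class samples landed in this validation fold.
    fold_class0_counts.append(int((y[val_idx] == 0).sum()))
```

The important part is re-creating the model inside the loop so that no fold sees weights trained on its own validation data.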

Upvotes: 2

Arigion

Reputation: 3548

Have a look at skorch. It's a scikit-learn compatible neural network library that wraps PyTorch. It provides CVSplit for cross-validation splits, or you can use scikit-learn's own cross-validation utilities. From the docs:

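from sklearn.model_selection import cross_val_predict
from skorch import NeuralNetClassifier

# MyModule is your own torch.nn.Module subclass; X and y are your data.
net = NeuralNetClassifier(
    module=MyModule,
    train_split=None,  # no internal validation split; sklearn makes the folds
)
y_pred = cross_val_predict(net, X, y, cv=5)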
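
If you want the stratification to be explicit rather than relying on what `cv=5` resolves to, `cross_val_predict` also accepts a splitter object. A small sketch of that, using `LogisticRegression` purely as a stand-in estimator so the example runs without defining a module; the toy data is hypothetical. Since the skorch net is scikit-learn compatible, it should be usable in place of `clf`, but that is an assumption I have not tested here.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Hypothetical imbalanced toy data: 180 samples of class 1, 20 of class 0.
rng = np.random.default_rng(0)
y = np.array([1] * 180 + [0] * 20)
X = rng.normal(size=(200, 4)) + y[:, None]  # shift class 1 so it is learnable

# Stand-in estimator; a scikit-learn compatible classifier would go here.
clf = LogisticRegression()

# Passing the splitter directly makes the stratified, shuffled folds explicit.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
y_pred = cross_val_predict(clf, X, y, cv=cv)
```

Each sample gets exactly one out-of-fold prediction, so `y_pred` has the same length as `y`.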

Upvotes: 1
