Reputation: 167
My goal is to do binary classification using a neural network. The problem is that the dataset is unbalanced: 90% of the samples are class 1 and 10% are class 0. To deal with this, I want to use stratified cross-validation.
The trouble is that I am working with PyTorch. I can't find any example, and the documentation doesn't provide one. I'm a student and quite new to neural networks.
Can anybody help? Thank you!
Upvotes: 2
Views: 7263
Reputation: 19834
The easiest way I've found is to do your stratified splits before passing your data to the PyTorch Dataset and DataLoader. That lets you avoid having to port all your code to skorch, which can break compatibility with some cluster computing frameworks.
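As a rough sketch of that approach, assuming scikit-learn is available for the splitting step: StratifiedKFold produces index arrays per fold, and torch.utils.data.Subset turns them into per-fold datasets. The synthetic data, batch size, and fold count below are placeholders, not values from the question.

```python
import numpy as np
import torch
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Subset, TensorDataset

# Synthetic stand-in data (assumption): 90% class 1, 10% class 0.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 8)).astype(np.float32)
y = np.array([1] * 180 + [0] * 20)

dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y))

# Split on the labels first, so each fold keeps the class ratio.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

fold_ratios = []
for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
    # Wrap the fold indices in Subset, then hand them to DataLoader as usual.
    train_loader = DataLoader(Subset(dataset, train_idx.tolist()),
                              batch_size=32, shuffle=True)
    val_loader = DataLoader(Subset(dataset, val_idx.tolist()), batch_size=32)

    # Create a fresh model here, train on train_loader, evaluate on val_loader.

    # Record the share of class 1 in this validation fold as a sanity check.
    fold_ratios.append(float(y[val_idx].mean()))
```

The training loop is left out on purpose. The point is only that the stratification happens on plain index arrays, so the rest of the PyTorch code does not need to change.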
Upvotes: 2
Reputation: 3548
Have a look at skorch. It's a scikit-learn compatible neural network library that wraps PyTorch. It provides CVSplit for cross-validation splits, or you can use sklearn's own cross-validation tools. From the docs:
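from sklearn.model_selection import cross_val_predict
from skorch import NeuralNetClassifier

# MyModule is your torch.nn.Module subclass; X and y are your features and labels.
net = NeuralNetClassifier(
    module=MyModule,
    train_split=None,  # skip skorch's internal validation split; sklearn makes the folds
)
# For a classifier with an integer cv, sklearn uses stratified k-fold splits.
y_pred = cross_val_predict(net, X, y, cv=5)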
Upvotes: 1