user42

How to ensure a certain datapoint is not in the test set in stratified cross validation split?

I have a DataFrame that looks like this.

import pandas as pd

d = {'col1': [1, 2, 3, 4, 5, 6, 7, 8], 'col2': ['a', 'a', 'b', 'b', 'c', 'c', 'd', 'd']}
df = pd.DataFrame(data=d)
df

  col1  col2
0   1     a
1   2     a
2   3     b
3   4     b
4   5     c
5   6     c
6   7     d
7   8     d

When I use k-fold cross-validation, I want to ensure that each value in col2 appears in only one of the two sets, either the train set or the test set. That is, if df['col2'][0] == 'a' and df['col2'][1] == 'a', then rows 0 and 1 should both be in the train set or both be in the test set. Row 0 must never end up in the train set while row 1 is in the test set.

Is there an easy way to do this?

Edit: Is there a way to simply split the DataFrame into two such that all rows with a given col2 value (for example, a) are in either the first DataFrame or the second, but not both? I tried using groupby, but it returns a GroupBy object, and when I convert it to a dictionary I can only access the groups by their keys, i.e. a, b, c, d.
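For the two-way split in the edit, a GroupBy object is not strictly needed. A minimal pandas sketch, assuming the held-out col2 values are picked at random (`n_test_groups` and the seed are hypothetical choices for illustration): choose a subset of the unique col2 values, then route rows with `isin`, so all rows sharing a value land in the same DataFrame.

```python
import numpy as np
import pandas as pd

d = {'col1': [1, 2, 3, 4, 5, 6, 7, 8],
     'col2': ['a', 'a', 'b', 'b', 'c', 'c', 'd', 'd']}
df = pd.DataFrame(data=d)

# Hypothetical parameters: hold out one of the four col2 values, fixed seed.
n_test_groups = 1
rng = np.random.default_rng(42)

# Shuffle the distinct col2 values and take the first n_test_groups for the test part.
groups = df['col2'].unique()
test_groups = rng.permutation(groups)[:n_test_groups]

# Rows whose col2 value is in test_groups go to df_test, all other rows to df_train,
# so no col2 value is ever split across the two DataFrames.
mask = df['col2'].isin(test_groups)
df_train, df_test = df[~mask], df[mask]

print(df_train)
print(df_test)
```

Because the split is decided per value rather than per row, the sizes of the two parts depend on how many rows each value has.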

Answers (2)

user42

With the help of @Antoine Dubuis, I found an sklearn implementation of what I wanted to do: StratifiedGroupKFold.

It is still in development as of July 2021, but it can be used from the development (nightly) version. I advise creating a separate virtual environment to use it.

I have used it and it seems to work, so I hope it will be included in a stable release soon.

Antoine Dubuis

You can ensure that each value of a variable appears in only one set by using GroupShuffleSplit, as follows:

from sklearn.model_selection import GroupShuffleSplit
import pandas as pd

d = {'col1': [1, 2, 3, 4, 5, 6, 7, 8],
     'col2': ['a', 'a', 'b', 'b', 'c', 'c', 'd', 'd'],
     'label': [1, 1, 1, 1, 0, 0, 0, 0]}
df = pd.DataFrame(data=d)

X = df[['col1', 'col2']]
y = df['label']
groups = df['col2']  # rows sharing a col2 value form one group

# train_size is a fraction of groups, not rows; a group is never split across train and test.
gss = GroupShuffleSplit(n_splits=2, train_size=.8, random_state=42)
for train_idx, test_idx in gss.split(X, y, groups=groups):
    X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
    y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
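GroupShuffleSplit draws independent random splits, so the test sets of different splits can overlap and some groups may never be tested. For proper k-fold cross-validation, where each group is held out exactly once, GroupKFold applies the same grouping constraint. A minimal sketch on the same data, assuming `n_splits=4` (one fold per col2 value, since GroupKFold needs at least as many groups as splits):

```python
import pandas as pd
from sklearn.model_selection import GroupKFold

d = {'col1': [1, 2, 3, 4, 5, 6, 7, 8],
     'col2': ['a', 'a', 'b', 'b', 'c', 'c', 'd', 'd'],
     'label': [1, 1, 1, 1, 0, 0, 0, 0]}
df = pd.DataFrame(data=d)

X = df[['col1', 'col2']]
y = df['label']
groups = df['col2']

# With 4 distinct groups and n_splits=4, each test fold is exactly one col2 value.
gkf = GroupKFold(n_splits=4)
folds = []
for train_idx, test_idx in gkf.split(X, y, groups=groups):
    X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
    y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
    folds.append((train_idx, test_idx))
    print('test group(s):', sorted(set(groups.iloc[test_idx])))
```

Unlike GroupKFold, StratifiedGroupKFold additionally tries to keep the label distribution similar across folds.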
