Reputation: 4755
From the docs (https://scikit-learn.org/stable/modules/cross_validation.html#group-k-fold):
GroupKFold is a variation of k-fold which ensures that the same group is not represented in both testing and training sets
Then, slightly adapting the example, we have:
import numpy as np
from sklearn.model_selection import GroupKFold

X = np.array([0.1, 0.2, 2.2, 2.4, 2.3, 4.55, 5.8, 8.8, 9, 10])
y = np.array(["a", "b", "b", "b", "c", "c", "c", "d", "d", "d"])
groups = [1, 1, 1, 2, 2, 2, 3, 3, 3, 3]

gkf = GroupKFold(n_splits=3)
for train, test in gkf.split(X, y, groups=groups):
    print("%s %s" % (train, test))
Which prints:
[0 1 2 3 4 5] [6 7 8 9]
[0 1 2 6 7 8 9] [3 4 5]
[3 4 5 6 7 8 9] [0 1 2]
To me it seems that group b is in both the testing and training sets here, though. Take the last output:
[3 4 5 6 7 8 9] [0 1 2]
Here the test indices are [0, 1, 2], which give us group a and two values from group b. That means there's a value from group b in the test set as well as in the training set (which contains index 3, another b).
Presumably the docs / module are correct, and I'm wrong, but I don't understand how.
To be clear - I'm expecting not to see values of the same group in both testing and training, but there are.
Upvotes: 2
Views: 1039
Reputation: 5174
You are mistaking the classes for the groups. As the comments already pointed out, the splits are determined by the groups parameter only and are independent of the classes.
You can get a better understanding of the example by following the description you already linked to:
For example if the data is obtained from different subjects with several samples per-subject and if the model is flexible enough to learn from highly person specific features it could fail to generalize to new subjects.
So the problem GroupKFold is designed for could be a situation where you have obtained data from different sources (subjects in the example) and want to check whether your model generalizes well enough to perform well on data from other sources. In other words, you want to make sure that your model has not overfitted to data from one particular source or another. And this is what GroupKFold is made for:
GroupKFold makes it possible to detect this kind of overfitting situation.
So these sources (or subjects) are specified by the groups parameter, and GroupKFold keeps them separated so that the same source is never represented in both the testing and training folds.
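You can see this directly with the toy data from the question by printing the group labels of each fold instead of the raw indices (a small sketch; the only thing added here is mapping the returned indices back through the groups array):

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Same toy data as in the question; groups made a NumPy array so it can be indexed.
X = np.array([0.1, 0.2, 2.2, 2.4, 2.3, 4.55, 5.8, 8.8, 9, 10])
y = np.array(["a", "b", "b", "b", "c", "c", "c", "d", "d", "d"])
groups = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 3])

gkf = GroupKFold(n_splits=3)
for train, test in gkf.split(X, y, groups=groups):
    # No group number ever appears on both sides of a split.
    assert set(groups[train]).isdisjoint(set(groups[test]))
    print("train groups:", sorted(set(groups[train])),
          "| test groups:", sorted(set(groups[test])))
```

Each fold holds out exactly one group (1, 2, or 3); the classes in y play no role in where the samples land.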
Upvotes: 3