Hantaa

"Stratify" parameter from sklearn's train_test_split not working correctly?

I have a problem with the stratify parameter of scikit-learn's train_test_split() function. Here is a dummy example that reproduces the same problem that appears randomly on my data:

from sklearn.model_selection import train_test_split
a = [1, 0, 0, 0, 0, 0, 0, 1]
train_test_split(a, stratify=a, random_state=42)

which returns:

[[1, 0, 0, 0, 0, 1], [0, 0]]

Shouldn't it also select a "1" for the test subset? Based on how I expect train_test_split() with stratify to work, it should return something like:

[[1, 0, 0, 0, 0, 0], [0, 1]]

This happens with some values of random_state, while with others it works correctly, but I can't search for the "right" value every time I analyse data.

I am using Python 2.7 and scikit-learn 0.18.

Answers (1)

DanielP

This question was asked 8 months ago, but an answer might still help future readers.

When you use the stratify parameter, train_test_split relies on the StratifiedShuffleSplit class to do the split. As its documentation states, StratifiedShuffleSplit aims to preserve the percentage of samples for each class, just as you expected.
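As a rough illustration of that mechanism, the sketch below calls StratifiedShuffleSplit directly with a single split (n_splits=1) on the same toy data, passing a zero array as a placeholder X because only the labels drive a stratified split. It is a sketch, not a copy of train_test_split's internals: the exact indices drawn may differ from train_test_split's output and between scikit-learn versions, so only the split sizes and per-class counts matter here.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

a = np.array([1, 0, 0, 0, 0, 0, 0, 1])

# One stratified split; test_size=0.25 of 8 samples gives a 2-sample test set.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
# X is only a placeholder; the labels in `a` decide the stratification.
train_idx, test_idx = next(sss.split(np.zeros(len(a)), a))

train, test = a[train_idx], a[test_idx]
print("train:", train.tolist(), "test:", test.tolist())
# Count how many samples of each class (0 and 1) landed in each subset.
print("train class counts:", np.bincount(train, minlength=2).tolist())
print("test class counts:", np.bincount(test, minlength=2).tolist())
```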

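The problem is that in your example 25% of the samples (2 of 8) are 1s, and with the default test_size of 0.25 the test set holds only 2 samples. 25% of 2 is half a sample, so the proportion cannot be reproduced exactly: the split has to round, and depending on random_state the test set gets either zero or one 1. You have two options here: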

A. Increase the size of the test set with the test_size option, which defaults to 0.25, to, say, 0.5. Half of your samples then become the test set, and 25% of them (i.e. 1 in 4) are 1s.

>>> a = [1, 0, 0, 0, 0, 0, 0, 1]
>>> train_test_split(a, stratify=a, random_state=42, test_size=0.5)
[[1, 0, 0, 0], [0, 0, 1, 0]]

B. Keep test_size at its default value and increase the size of your set a so that the test set (25% of the samples) contains at least 4 elements; 25% of 4 is exactly one 1. Keeping the same 1-in-4 ratio of 1s, an a of 16 samples or more will do that for you.

>>> a = [1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1]
>>> train_test_split(a, stratify=a, random_state=42)
[[0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0]]
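To check that both options fix the split for every seed, not only for random_state=42, one can repeat the split over many seeds and tally how many 1s land in the test set. ones_in_test is a hypothetical helper written for this sketch, and the range of 100 seeds is arbitrary. The comments assume that a float test_size is rounded up to a whole number of test samples, as recent scikit-learn versions do.

```python
from collections import Counter

from sklearn.model_selection import train_test_split


def ones_in_test(a, test_size=0.25, seeds=range(100)):
    """Hypothetical helper: for each random_state, count how many 1s
    end up in the stratified test set, and tally the outcomes."""
    outcomes = Counter()
    for seed in seeds:
        _, test = train_test_split(a, stratify=a, random_state=seed,
                                   test_size=test_size)
        outcomes[sum(test)] += 1
    return outcomes


original = [1, 0, 0, 0, 0, 0, 0, 1]
larger = original * 2  # 16 samples, four 1s: still 25% ones

# Default test_size: 2 test samples, expected 0.5 ones, so it varies by seed.
print(ones_in_test(original))
# Option A: test_size=0.5 gives 4 test samples, expected exactly 1.0 ones.
print(ones_in_test(original, test_size=0.5))
# Option B: 16 samples at the default test_size give 4 test samples, expected 1.0 ones.
print(ones_in_test(larger))
```

With the default settings the tally should mix outcomes of 0 and 1, while both options should put exactly one 1 in the test set for every seed, because the expected count is then a whole number and no rounding is needed.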

Hope that helps.
