pepe

Reputation: 9909

Erratic behavior of train_test_split() in scikit-learn

Python 3.5 (anaconda install) SciKit 0.17.1

I just can't understand why train_test_split() has been giving me what I consider unreliable splits of a list of training cases.

Here's an example. My list trnImgPaths has 3 classes, each one with 67 images (total 201 images):

['/Caltech101/ferry/image_0001.jpg',
   ... thru ...
 '/Caltech101/ferry/image_0067.jpg',
 '/Caltech101/laptop/image_0001.jpg',
   ... thru ...
 '/Caltech101/laptop/image_0067.jpg',
 '/Caltech101/airplane/image_0001.jpg',
   ... thru ...
 '/Caltech101/airplane/image_0067.jpg']

My list of targets trnImgTargets matches this in length, and its classes align perfectly with trnImgPaths.

In[148]: len(trnImgPaths)
Out[148]: 201
In[149]: len(trnImgTargets)
Out[149]: 201

If I run:

[trnImgs, testImgs, trnTargets, testTargets] = \
    train_test_split(trnImgPaths, trnImgTargets, test_size=141, train_size=60, random_state=42)

or

[trnImgs, testImgs, trnTargets, testTargets] = \
    train_test_split(trnImgPaths, trnImgTargets, test_size=0.7, train_size=0.3, random_state=42)

or

[trnImgs, testImgs, trnTargets, testTargets] = \
    train_test_split(trnImgPaths, trnImgTargets, test_size=0.7, train_size=0.3)

I always end up getting:

In[150]: len(trnImgs)
Out[150]: 60
In[151]: len(testImgs)
Out[151]: 141
In[152]: len(trnTargets)
Out[152]: 60
In[153]: len(testTargets)
Out[153]: 141

I never get a perfect 20 - 20 - 20 split for the training set. I can tell both by manual checking and by a sanity check with a confusion matrix. Here are the results for each experiment above, respectively:

[[19  0  0]
 [ 0 21  0]
 [ 0  0 20]]

[[19  0  0]
 [ 0 21  0]
 [ 0  0 20]]

[[16  0  0]
 [ 0 22  0]
 [ 0  0 22]]

I expected the split to be perfectly balanced. Any thoughts why this is happening?

It even appears it may be misclassifying a few cases a priori, because there will never be n=22 training cases for a given class.
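The per-class counts can be checked without a confusion matrix. This is a minimal sketch (the `labels` list is a hypothetical stand-in for trnImgTargets) showing that a plain random 60-sample draw from 3 x 67 labels generally does not land on 20 per class:

```python
from collections import Counter
import random

# Hypothetical stand-in for trnImgTargets: 3 classes, 67 labels each.
labels = ['ferry'] * 67 + ['laptop'] * 67 + ['airplane'] * 67

random.seed(42)
shuffled = random.sample(labels, len(labels))  # what a plain random split does
train = shuffled[:60]                          # 60-case training set

print(Counter(train))  # per-class counts; usually not exactly 20 each
```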

Upvotes: 2

Views: 190

Answers (2)

pepe

Reputation: 9909

Based on @lejlot's comments, the way I managed to lock in the number of cases was by using a feature new to train_test_split in scikit-learn 0.17: the stratify argument, which I'm using as follows (this forces the split to preserve the class proportions of your label list):

[trnImgs, testImgs, trnTargets, testTargets] = \
    train_test_split(trnImgPaths, trnImgTargets, test_size=0.7,
                     train_size=0.3, stratify=trnImgTargets)

Now, every time I run the script I get:

[[20  0  0]
 [ 0 20  0]
 [ 0  0 20]]

Upvotes: 0

lejlot

Reputation: 66805

In short: this is expected behaviour.

Random splitting does not guarantee "balanced" splits. This is what stratified splitting is for (also implemented in sklearn).
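To illustrate, here is a minimal sketch with synthetic stand-ins for the asker's lists (the `paths`/`labels` names are hypothetical; in current scikit-learn the import lives in sklearn.model_selection, while in 0.17 it was sklearn.cross_validation):

```python
from collections import Counter
from sklearn.model_selection import train_test_split  # sklearn.cross_validation in 0.17

# Synthetic stand-ins: 3 classes x 67 items, 201 total.
labels = ['ferry'] * 67 + ['laptop'] * 67 + ['airplane'] * 67
paths = ['img_%03d' % i for i in range(len(labels))]

# Stratified split: class proportions are preserved in both halves,
# so a 60-case training set gets exactly 20 per class.
tr_p, te_p, tr_y, te_y = train_test_split(
    paths, labels, train_size=60, test_size=141,
    stratify=labels, random_state=42)

print(Counter(tr_y))  # 20 per class
```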

Upvotes: 1
