Romain Jouin

Reputation: 4838

Pandas + Scikit learn : issue with stratified k-fold

When used with a DataFrame, StratifiedKFold from scikit-learn returns arrays of positions from 0 to n-1 instead of values from the DataFrame's index. Is there a way to change that?

Example:

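import pandas as pd
# Legacy import matching the StratifiedKFold(y) call below (module removed in scikit-learn 0.20)
from sklearn.cross_validation import StratifiedKFold

df = pd.DataFrame()
df["test"] = (0, 1, 2, 3, 4, 5, 6)
df.index   = ('a', 'b', 'c', 'd', 'e', 'f', 'g')
for i, (train, test) in enumerate(StratifiedKFold(df.index)):
    print i, (train, test)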

Gives:

0 (array([], dtype=int64), array([0, 1, 2, 3, 4, 5, 6]))
1 (array([0, 1, 2, 3, 4, 5, 6]), array([], dtype=int64))
2 (array([0, 1, 2, 3, 4, 5, 6]), array([], dtype=int64))

I would expect the labels from the DataFrame's index to be returned, not positions in the range of the DataFrame's length.

Upvotes: 4

Views: 3580

Answers (1)

yangjie

Reputation: 6725

The numbers you got are positions within df.index selected by StratifiedKFold, not the index labels themselves.

To map them back to the labels of your DataFrame's index, use them to index into df.index:

for i, (train, test) in enumerate(StratifiedKFold(df.index)):
    print i, (df.index[train], df.index[test])

which gives

0 (Index([], dtype='object'), Index([u'a', u'b', u'c', u'd', u'e', u'f', u'g'], dtype='object'))
1 (Index([u'a', u'b', u'c', u'd', u'e', u'f', u'g'], dtype='object'), Index([], dtype='object'))
2 (Index([u'a', u'b', u'c', u'd', u'e', u'f', u'g'], dtype='object'), Index([], dtype='object'))
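For readers on a newer scikit-learn, here is a minimal sketch of the same mapping, assuming version 0.18 or later, where StratifiedKFold lives in sklearn.model_selection and takes the data through split(X, y). The "label" column, the six-row frame, and n_splits=3 are assumptions added for illustration; they are not part of the question. Stratifying on a column with repeated classes, rather than on the unique index itself, also avoids the empty folds seen above.

```python
import pandas as pd
from sklearn.model_selection import StratifiedKFold

# Hypothetical data: the "label" column is an assumption, added so that
# each class has enough members to be spread over the folds.
df = pd.DataFrame(
    {"test": [0, 1, 2, 3, 4, 5], "label": [0, 0, 0, 1, 1, 1]},
    index=list("abcdef"),
)

skf = StratifiedKFold(n_splits=3)

folds = []
# split() still yields integer positions, exactly as in the question
for i, (train_pos, test_pos) in enumerate(skf.split(df, df["label"])):
    # Translate positions into the DataFrame's own index labels
    train_labels = df.index[train_pos]
    test_labels = df.index[test_pos]
    folds.append((train_labels, test_labels))
    print(i, list(train_labels), list(test_labels))
```

If you need the rows rather than the labels, df.iloc[train_pos] selects them by position directly, and df.loc[train_labels] does the same by label.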

Upvotes: 3
