Wonton

Reputation: 339

Scikit-Learn GroupShuffleSplit is not grouping by specified groups

I am trying to split a timeseries of farm data taken at a daily frequency for 8 years. I want to split the data so that the train and test sets each contain samples from different farms, and there is no overlap of farms between the train and test sets. I have created a column in the dataframe containing the unique FarmID indicating which farm the sample came from.

Visually, this is what the dataset looks like in general:

df

╔════════╦════════════╦═══════════╦═════╦═══════════╗
║ FarmID ║  datetime  ║ Feature_1 ║ ... ║ Feature_n ║
╠════════╬════════════╬═══════════╬═════╬═══════════╣
║ 0      ║ 2009-01-01 ║ 45.76     ║ ... ║ 15.12     ║
║ ...    ║ ...        ║ ...       ║ ... ║ ...       ║
║ 3668   ║ 2017-12-31 ║ 12.12     ║ ... ║ 15.75     ║
╚════════╩════════════╩═══════════╩═════╩═══════════╝
6702142 rows × 35 columns


df[df.FarmID==0]

╔════════╦════════════╦═══════════╦═════╦═══════════╗
║ FarmID ║  datetime  ║ Feature_1 ║ ... ║ Feature_n ║
╠════════╬════════════╬═══════════╬═════╬═══════════╣
║ 0      ║ 2009-01-01 ║ 35.31     ║ ... ║ 67.41     ║
║ ...    ║ ...        ║ ...       ║ ... ║ ...       ║
║ 0      ║ 2017-12-31 ║ 2.15      ║ ... ║ 5.21      ║
╚════════╩════════════╩═══════════╩═════╩═══════════╝
1096 rows × 35 columns


# Note: Not all farms contain the same number of samples as some farms didn't submit data in some years.

To split the dataset, this is the code I have used:

df = df.sort_values('FarmID')

def group_split(df, test_size=.80, seed=seed):
    from sklearn.model_selection import GroupShuffleSplit
    gss = GroupShuffleSplit(1, test_size, random_state=seed)

    for test_indices, train_indices in gss.split(df, groups=df.FarmID):
        train = df.loc[train_indices]
        test = df.loc[test_indices]

    return train, test

train, test = group_split(df)

Upon inspecting the unique farms contained in the train-test splits, I see that there are some farms contained in both the train and test set.

In: train.FarmID.unique()

Out: array([2.000e+00, 4.000e+00, 8.000e+00, ..., 2.245e+03, 2.229e+03,
            2.575e+03])


In: test.FarmID.unique()

Out: array([0.000e+00, 1.000e+00, 1.300e+01, ..., 2.245e+03, 2.229e+03,
            2.575e+03])


In: n = 2245
    df[df.FarmID==n].shape
    train[train.FarmID==n].shape
    test[test.FarmID==n].shape

Out: (1826, 35)
     (1225, 35)
     (601, 35)

However, there are some farms which are split correctly.

In: n = 3668
    df[df.FarmID==n].shape
    train[train.FarmID==n].shape
    test[test.FarmID==n].shape

Out: (705, 35)
     (705, 35)
     (0, 35)

Furthermore, 995 of the 3669 farms are overlapping in the train-test sets.

In: train_FarmIDs = train.FarmID.unique()
    test_FarmIDs = test.FarmID.unique()
    len(set(train_FarmIDs).intersection(set(test_FarmIDs)))

Out: 995

I'm absolutely stumped as to why sklearn's GroupShuffleSplit isn't splitting correctly by the groups I specified. I would really appreciate it if someone could help me with this issue!

Upvotes: 3

Views: 2022

Answers (1)

Yuval Nezri

Reputation: 26

This is only a guess, but I think gss converts your dataframe to an ndarray and returns positional indices into that ndarray. Sorting the df reorders its rows but keeps the original index labels, so labels and positions no longer line up, and .loc[] (which selects by label) then picks the wrong rows. Try using .iloc[] instead, or convert your df to a numpy array before using gss and then slice the numpy array rather than the dataframe.
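A minimal sketch of the .iloc[] suggestion, in case it helps. The toy dataframe, the default test_size=0.2, and seed=0 are assumptions made up for illustration and are not taken from the question. Note also that split() yields the train positions first and the test positions second, so the order of the two loop variables in the question may be worth double-checking.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit


def group_split(df, test_size=0.2, seed=0):
    """Split df into train/test so that no FarmID appears in both."""
    gss = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    # split() yields (train, test) arrays of row positions, not index labels.
    train_pos, test_pos = next(gss.split(df, groups=df["FarmID"]))
    # .iloc selects by position, so it is safe even when the index is shuffled.
    return df.iloc[train_pos], df.iloc[test_pos]


# Hypothetical stand-in for the farm data: 20 farms with 5 rows each.
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "FarmID": np.repeat(np.arange(20), 5),
    "Feature_1": rng.normal(size=100),
})
# Shuffle, then sort by FarmID: the index labels no longer match row positions,
# which mimics the situation after df.sort_values('FarmID') in the question.
toy = toy.sample(frac=1, random_state=0).sort_values("FarmID")

train, test = group_split(toy)
overlap = set(train["FarmID"]) & set(test["FarmID"])
print(len(overlap))  # number of farms present in both sets
```

If you would rather keep .loc[], calling df.reset_index(drop=True) after sorting should make the labels match the positions again.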

Upvotes: 1
