Reputation: 6135
I've just seen this answer on SO which shows how to split data using numpy.
Assume we're going to split them as 0.8, 0.1, 0.1 for training, testing, and validation respectively; you do it this way:
train, test, val = np.split(df, [int(.8 * len(df)), int(.9 * len(df))])
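For reference, the two integers passed to np.split are cut positions, not sizes: the first piece runs up to the 80% mark, the second from the 80% mark to the 90% mark, and the third takes the rest. A minimal sketch on a plain array of ten items (the array is made up for illustration):

```python
import numpy as np

data = np.arange(10)  # hypothetical stand-in for ten rows
n = len(data)

# Cut positions at 80% and 90% of the length, i.e. indices 8 and 9
train, test, val = np.split(data, [int(.8 * n), int(.9 * n)])

print(train)  # [0 1 2 3 4 5 6 7]
print(test)   # [8]
print(val)    # [9]
```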
I'd like to know how I could stratify while splitting the data using this method.
Stratifying is splitting data while keeping the priors of each class you have in the data. That is, if you're going to take 0.8 for the training set, you take 0.8 from each class you have. Same for test and validation.
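To make the definition concrete, here is a small sketch with a made-up frame of 20 rows of class "a" and 10 rows of class "b" (the column name label is an assumption for this example). A stratified 0.8 training share should contain 16 rows of "a" and 8 rows of "b", so its class proportions match those of the full data:

```python
import pandas as pd

# Hypothetical data: 20 rows of class "a", 10 rows of class "b"
df = pd.DataFrame({"label": ["a"] * 20 + ["b"] * 10, "x": range(30)})

# Take the first 80% of the rows of every class
train = pd.concat([g.iloc[: int(.8 * len(g))] for _, g in df.groupby("label")])

# Class proportions (priors) in the full data and in the training share
full_priors = df["label"].value_counts(normalize=True).sort_index()
train_priors = train["label"].value_counts(normalize=True).sort_index()

train_counts = train["label"].value_counts().sort_index().to_dict()
print(train_counts)  # {'a': 16, 'b': 8}
```

Both full_priors and train_priors come out as two thirds for "a" and one third for "b".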
I tried grouping the data first by class using:
grouped_df = df.groupby(class_col_name, group_keys=False)
But it did not show correct results.
Note: I'm familiar with train_test_split
Upvotes: 3
Views: 4957
Reputation: 107747
Simply use your groupby object, grouped_df, which consists of each subsetted data frame, and run the needed np.split on each one. Then concatenate all the sampled data frames with pd.concat. Altogether, this would stratify according to your quoted message:
import numpy as np
import pandas as pd

train_list = []; test_list = []; val_list = []
grouped_df = df.groupby(class_col_name)
# ITERATE THROUGH EACH SUBSET DF
for i, g in grouped_df:
    # STRATIFY THE g (CLASS) DATA FRAME
    train, test, val = np.split(g, [int(.8 * len(g)), int(.9 * len(g))])
    train_list.append(train); test_list.append(test); val_list.append(val)
final_train = pd.concat(train_list)
final_test = pd.concat(test_list)
final_val = pd.concat(val_list)
Alternatively, a short-hand version using list comprehensions:
# LIST OF ARRAYS
arr_list = [np.split(g, [int(.8 * len(g)), int(.9 * len(g))]) for i, g in grouped_df]
final_train = pd.concat([t[0] for t in arr_list])
final_test = pd.concat([t[1] for t in arr_list])
final_val = pd.concat([v[2] for v in arr_list])
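Note that the code above takes rows in their existing order within each class. If the rows should be shuffled first, one possible variation is sketched below. It uses DataFrame.sample and iloc slicing instead of np.split; the function name stratified_split, its parameters, and the demo column names are assumptions made for this illustration, not part of the answer above.

```python
import pandas as pd

def stratified_split(df, class_col, cuts=(.8, .9), seed=0):
    """Split df into train/test/val, cutting every class at the same positions."""
    train_list, test_list, val_list = [], [], []
    for _, g in df.groupby(class_col):
        g = g.sample(frac=1, random_state=seed)  # shuffle rows within the class
        a = int(cuts[0] * len(g))                # end of the training share
        b = int(cuts[1] * len(g))                # end of the test share
        train_list.append(g.iloc[:a])
        test_list.append(g.iloc[a:b])
        val_list.append(g.iloc[b:])
    return pd.concat(train_list), pd.concat(test_list), pd.concat(val_list)

# Hypothetical demo data: 50 rows of class 0 and 30 rows of class 1
demo = pd.DataFrame({"y": [0] * 50 + [1] * 30, "x": range(80)})
train, test, val = stratified_split(demo, "y")
print(len(train), len(test), len(val))  # 64 8 8
```

Because int() truncates, very small classes can end up with empty test or validation pieces, so it is worth checking the per-class counts on real data.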
Upvotes: 3
Reputation: 2071
This assumes you have already assigned strata, such that a category column (here 'headline') indicates which stratum each entry belongs to.
from collections import namedtuple
Dataset = namedtuple('Dataset', 'train test val')
grouped = df.groupby('headline')
splitted = {x: grouped.get_group(x).sample(frac=1) for x in grouped.groups}
datasets = {k:Dataset(*np.split(df, [int(.8 * len(df)), int(.9 * len(df))])) for k, df in splitted.items()}
This stores each stratified split by the category name assigned in df.
Each item in datasets is a Dataset namedtuple such that training, testing, and validation subsets are accessible by .train, .test, and .val respectively.
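As a usage sketch, the per-category splits can be looked up by key, or stitched back into overall sets with pd.concat. The small frame below is made up, and the splits are built with iloc slicing rather than np.split purely to keep the example self-contained; that substitution is an assumption of this sketch.

```python
from collections import namedtuple
import pandas as pd

Dataset = namedtuple('Dataset', 'train test val')

# Hypothetical data: two categories in a 'headline' column, ten rows each
df = pd.DataFrame({'headline': ['a'] * 10 + ['b'] * 10, 'x': range(20)})

datasets = {}
for name, g in df.groupby('headline'):
    g = g.sample(frac=1, random_state=0)       # shuffle within the category
    i, j = int(.8 * len(g)), int(.9 * len(g))  # cut positions
    datasets[name] = Dataset(g.iloc[:i], g.iloc[i:j], g.iloc[j:])

# Training subset of a single category
print(len(datasets['a'].train))  # 8

# Overall training set across all categories
final_train = pd.concat([d.train for d in datasets.values()])
print(len(final_train))  # 16
```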
Upvotes: 1