Aravind Chamakura

Grouping data with sklearn.model_selection.GroupShuffleSplit

I have a dataset in a CSV file with the following header:

PRODUCT_ID  CATEGORY_NAME   PRODUCT_TYPE DISPLAY_COLOR_NAME IMAGE_ID                        

The same product appears in multiple rows, each with a different IMAGE_ID. I set IMAGE_ID as the index column when reading the CSV into a pandas DataFrame.
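For illustration, here is a minimal sketch of that layout. It assumes a comma-separated file; the rows are made up, and `io.StringIO` stands in for the real file path. Passing `index_col` to `pd.read_csv` makes IMAGE_ID the row index, and grouping on PRODUCT_ID shows how many image rows each product has:

```python
import io

import pandas as pd

# Hypothetical CSV contents with the question's header (assumes commas as separators)
csv_text = """PRODUCT_ID,CATEGORY_NAME,PRODUCT_TYPE,DISPLAY_COLOR_NAME,IMAGE_ID
101,Apparel,shirt,Blue,img_001
101,Apparel,shirt,Blue,img_002
102,Footwear,shoe,Black,img_003
"""

# index_col turns IMAGE_ID into the row index instead of a regular column
data = pd.read_csv(io.StringIO(csv_text), index_col="IMAGE_ID")

# Number of image rows per product: product 101 has two images, 102 has one
print(data.groupby("PRODUCT_ID").size())
```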

I want to create train and test datasets by grouping the data on PRODUCT_TYPE (or any other column). I also need to make sure the same product does not end up in both the train and test sets, since each product has multiple rows with different images.

How can I achieve this using sklearn.model_selection.GroupShuffleSplit?
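One possible approach is to pass the column to group on as the `groups` argument. The following sketch uses a made-up DataFrame with the question's column names. `split_by_column` is a hypothetical helper, not part of scikit-learn. Each product here has a single PRODUCT_TYPE, so grouping on PRODUCT_TYPE also keeps all of a product's images on the same side of the split:

```python
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical rows using the question's column names; the values are made up
data = pd.DataFrame(
    {
        "IMAGE_ID": ["a1", "a2", "b1", "b2", "c1", "d1", "d2", "e1"],
        "PRODUCT_ID": [1, 1, 2, 2, 3, 4, 4, 5],
        "PRODUCT_TYPE": ["shirt", "shirt", "shoe", "shoe", "shirt", "hat", "hat", "shoe"],
    }
).set_index("IMAGE_ID")


def split_by_column(df, column, test_size=0.3, seed=0):
    """Split rows so that no value of `column` appears in both train and test."""
    # test_size is a fraction of distinct groups, not of rows
    gss = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    train_idx, test_idx = next(gss.split(df, groups=df[column]))
    # split() yields positional row indices, hence iloc
    return df.iloc[train_idx], df.iloc[test_idx]


train, test = split_by_column(data, "PRODUCT_TYPE")
print(sorted(train["PRODUCT_TYPE"].unique()), sorted(test["PRODUCT_TYPE"].unique()))
```

Calling `split_by_column(data, "PRODUCT_ID")` instead gives a product-level split. In that case, different products of the same type may appear in both sets.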


Answers (1)

Aravind Chamakura

Here is the code I came up with. I would like to know if this can be done better:

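import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# DataFrame.from_csv was removed in pandas 1.0; read_csv does the same job
data = pd.read_csv('/myfileLocation/PRODUCTS.csv', index_col='IMAGE_ID')

# Group on PRODUCT_ID so all images of a product land in the same split.
# (IMAGE_ID is unique per row, so grouping on the index groups nothing.)
# Use data['PRODUCT_TYPE'] instead to keep whole product types together.
# test_size=0.3 is a fraction of groups (products), not of rows.
# y is ignored by GroupShuffleSplit, so it is omitted.
gss = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
train_idx, test_idx = next(gss.split(data, groups=data['PRODUCT_ID']))

# split() yields positional row indices, so select rows with iloc
train_dataset = data.iloc[train_idx]
test_dataset = data.iloc[test_idx]

print(test_dataset)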
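A note on the choice of `groups`: passing `data.index.values` (the unique IMAGE_ID) makes every row its own group. GroupShuffleSplit then behaves like a plain row-level shuffle, so images of the same product can land in both splits. The following sketch compares the two choices on made-up data with 4 products of 3 images each. `overlapping_products` is a hypothetical helper written for this comparison:

```python
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical data: IMAGE_ID is unique per row, PRODUCT_ID repeats
data = pd.DataFrame(
    {
        "IMAGE_ID": [f"img{i}" for i in range(12)],
        "PRODUCT_ID": [i // 3 for i in range(12)],  # 4 products, 3 images each
    }
).set_index("IMAGE_ID")


def overlapping_products(groups, seed):
    """Return the PRODUCT_IDs that appear in both train and test."""
    gss = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=seed)
    train_idx, test_idx = next(gss.split(data, groups=groups))
    train_products = set(data["PRODUCT_ID"].iloc[train_idx])
    test_products = set(data["PRODUCT_ID"].iloc[test_idx])
    return train_products & test_products


# Grouping by PRODUCT_ID never puts a product on both sides
for seed in range(20):
    assert not overlapping_products(data["PRODUCT_ID"], seed)

# Grouping by the unique IMAGE_ID index leaks products. Here every split does:
# the test set gets 4 rows, which cannot consist of whole 3-image products.
leaky = sum(bool(overlapping_products(data.index.values, s)) for s in range(20))
print(f"{leaky} of 20 IMAGE_ID-grouped splits shared a product")
```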
