Mohsin Khalid

Stratified split in SQL

Suppose I have a table like this:

Customer Label
Sally 1
Molly 1
James 0
Kyle 0
Sara 0
Mathew 1
Kelly 1
Brad 1
Sam 0
Alex 0

How can I split the rows into two tables, Train and Test, so that Train gets 80% of the rows and Test gets the remaining 20%?

The catch is that the split must be stratified with respect to the Label column: after the split, the proportion of 1s and 0s in each table should be the same as in the original table, which is 50%.

For example, the Train table could be:

Customer Label
Sally 1
Molly 1
Mathew 1
James 0
Kyle 0
Sara 0

and the Test table could be:

Customer Label
Kelly 1
Brad 1
Sam 0
Alex 0

As you can see, both tables keep the 50/50 proportion of 1s and 0s found in the original table.
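To make the requirement concrete, here is a small Python check (an illustration, not part of the question) that compares label proportions exactly using the sample rows above. It confirms the example split is stratified, and also shows that Train holds 6 of 10 rows (60%) rather than 80%. The helper names `label_shares` and `is_stratified` are hypothetical.

```python
from collections import Counter
from fractions import Fraction

# Rows from the question as (customer, label) pairs.
original = [("Sally", 1), ("Molly", 1), ("James", 0), ("Kyle", 0), ("Sara", 0),
            ("Mathew", 1), ("Kelly", 1), ("Brad", 1), ("Sam", 0), ("Alex", 0)]
# The example Train/Test split from the question.
train = [("Sally", 1), ("Molly", 1), ("Mathew", 1), ("James", 0), ("Kyle", 0), ("Sara", 0)]
test = [("Kelly", 1), ("Brad", 1), ("Sam", 0), ("Alex", 0)]


def label_shares(rows):
    """Map each label to the exact fraction of rows that carry it."""
    counts = Counter(label for _, label in rows)
    return {label: Fraction(n, len(rows)) for label, n in counts.items()}


def is_stratified(original, parts):
    """True if every part has the same label shares as the original table."""
    target = label_shares(original)
    return all(label_shares(part) == target for part in parts)


print(is_stratified(original, [train, test]))  # True
print(Fraction(len(train), len(original)))     # 3/5, i.e. Train holds 60% of the rows
```

Fractions are used instead of floats so the equality test is exact.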


Answers (1)

lemon

One option is to number the rows within each Label value, i.e. assign a row number with ROW_NUMBER() partitioned by Label. A row then goes to the training set if its number is at or below a per-label threshold, and to the test set otherwise (e.g. with a threshold of 3, rows 1 to 3 of each label go to training and the rest to testing). Because the threshold is computed separately for each label and the per-label results are then combined with a UNION, the split stays proportional even when the label distribution is imbalanced.
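The row-numbering idea can be sketched in plain Python as follows. This is an illustration only, assuming the rows fit in memory and that the training share is given as an integer percentage `train_pct` (a hypothetical parameter). The ceiling is computed with integer arithmetic to avoid floating-point rounding pushing a product like `count * 0.8` just past an integer.

```python
from collections import defaultdict

# Rows from the question as (customer, label) pairs.
rows = [("Sally", 1), ("Molly", 1), ("James", 0), ("Kyle", 0), ("Sara", 0),
        ("Mathew", 1), ("Kelly", 1), ("Brad", 1), ("Sam", 0), ("Alex", 0)]


def stratified_split(rows, train_pct):
    """Return (train, test), rounding each label's training share up."""
    by_label = defaultdict(list)
    for row in rows:
        by_label[row[1]].append(row)  # keeps input order within each label

    train, test = [], []
    for label_rows in by_label.values():
        # Integer form of CEIL(count * train_pct / 100): the per-label threshold.
        threshold = (len(label_rows) * train_pct + 99) // 100
        # rn plays the role of ROW_NUMBER() OVER (PARTITION BY Label).
        for rn, row in enumerate(label_rows, start=1):
            (train if rn <= threshold else test).append(row)
    return train, test


train, test = stratified_split(rows, train_pct=80)
print(len(train), len(test))  # 8 2
print(test)                   # [('Brad', 1), ('Alex', 0)]
```

With 5 rows per label and 80%, the threshold is 4, so the last row of each label (in input order) goes to the test set.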

Here's how you would get your training set:

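SET @train_size = 0.8;  -- fraction of each label that goes to the training set

WITH cte AS (
    SELECT *, ROW_NUMBER() OVER(PARTITION BY Label) AS rn
    FROM tab
)
SELECT Customer, Label
FROM cte
WHERE Label = 0
  AND rn <= (SELECT CEIL(COUNT(*) * @train_size) FROM cte WHERE Label = 0)
UNION ALL  -- the two halves are disjoint, so no duplicate elimination is needed
SELECT Customer, Label
FROM cte
WHERE Label = 1
  AND rn <= (SELECT CEIL(COUNT(*) * @train_size) FROM cte WHERE Label = 1);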
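To try the training-set query without a MySQL server, here is a sketch that runs an SQLite adaptation of it through Python's built-in sqlite3 module. It assumes SQLite 3.25 or later (needed for window functions). The bound integer percentage `pct` is a hypothetical stand-in for `@train_size`, and `(COUNT(*) * pct + 99) / 100` replaces `CEIL`, which is not available in every SQLite build. Only per-label counts are checked, because without an ORDER BY the numbering is not guaranteed to follow insertion order.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tab (Customer TEXT, Label INTEGER)")
conn.executemany("INSERT INTO tab VALUES (?, ?)", [
    ("Sally", 1), ("Molly", 1), ("James", 0), ("Kyle", 0), ("Sara", 0),
    ("Mathew", 1), ("Kelly", 1), ("Brad", 1), ("Sam", 0), ("Alex", 0),
])

# SQLite adaptation of the training-set query: :pct (an integer percentage)
# replaces @train_size, and (COUNT(*) * :pct + 99) / 100 is an integer CEIL.
TRAIN_SQL = """
WITH cte AS (
    SELECT *, ROW_NUMBER() OVER (PARTITION BY Label) AS rn
    FROM tab
)
SELECT Customer, Label FROM cte
WHERE Label = 0
  AND rn <= (SELECT (COUNT(*) * :pct + 99) / 100 FROM cte WHERE Label = 0)
UNION ALL
SELECT Customer, Label FROM cte
WHERE Label = 1
  AND rn <= (SELECT (COUNT(*) * :pct + 99) / 100 FROM cte WHERE Label = 1)
"""

train = conn.execute(TRAIN_SQL, {"pct": 80}).fetchall()
print(len(train))  # 8
```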

And here is your test set:

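SET @train_size = 0.8;  -- fraction of each label that goes to the training set

WITH cte AS (
    SELECT *, ROW_NUMBER() OVER(PARTITION BY Label) AS rn
    FROM tab
)
SELECT Customer, Label
FROM cte
WHERE Label = 0
  AND rn > (SELECT CEIL(COUNT(*) * @train_size) FROM cte WHERE Label = 0)
UNION ALL  -- the two halves are disjoint, so no duplicate elimination is needed
SELECT Customer, Label
FROM cte
WHERE Label = 1
  AND rn > (SELECT CEIL(COUNT(*) * @train_size) FROM cte WHERE Label = 1);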
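To try the test-set query without a MySQL server, here is a sketch that runs an SQLite adaptation of it through Python's sqlite3 module, assuming SQLite 3.25 or later for window functions. As an assumption of this sketch, the integer percentage `pct` stands in for `@train_size` and `(COUNT(*) * pct + 99) / 100` stands in for `CEIL`. With 80%, each label should leave exactly one of its five rows for testing.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tab (Customer TEXT, Label INTEGER)")
conn.executemany("INSERT INTO tab VALUES (?, ?)", [
    ("Sally", 1), ("Molly", 1), ("James", 0), ("Kyle", 0), ("Sara", 0),
    ("Mathew", 1), ("Kelly", 1), ("Brad", 1), ("Sam", 0), ("Alex", 0),
])

# SQLite adaptation of the test-set query: :pct (an integer percentage)
# replaces @train_size, and (COUNT(*) * :pct + 99) / 100 is an integer CEIL.
TEST_SQL = """
WITH cte AS (
    SELECT *, ROW_NUMBER() OVER (PARTITION BY Label) AS rn
    FROM tab
)
SELECT Customer, Label FROM cte
WHERE Label = 0
  AND rn > (SELECT (COUNT(*) * :pct + 99) / 100 FROM cte WHERE Label = 0)
UNION ALL
SELECT Customer, Label FROM cte
WHERE Label = 1
  AND rn > (SELECT (COUNT(*) * :pct + 99) / 100 FROM cte WHERE Label = 1)
"""

test = conn.execute(TEST_SQL, {"pct": 80}).fetchall()
print(sorted(label for _, label in test))  # [0, 1]: one row of each label
```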

If you prefer, you can merge the two sets into a single result with an extra column, split, whose value is 'train' or 'test' and tells you which set each row belongs to:

SET @train_size = 0.8;  -- fraction of each label that goes to the training set

WITH cte AS (
    SELECT *, ROW_NUMBER() OVER(PARTITION BY Label) AS rn
    FROM tab
)
SELECT Customer, Label, 'train' AS split
FROM cte
WHERE Label = 0
  AND rn <= (SELECT CEIL(COUNT(*) * @train_size) FROM cte WHERE Label = 0)
UNION ALL
SELECT Customer, Label, 'train' AS split
FROM cte
WHERE Label = 1
  AND rn <= (SELECT CEIL(COUNT(*) * @train_size) FROM cte WHERE Label = 1)
UNION ALL
SELECT Customer, Label, 'test' AS split
FROM cte
WHERE Label = 0
  AND rn > (SELECT CEIL(COUNT(*) * @train_size) FROM cte WHERE Label = 0)
UNION ALL
SELECT Customer, Label, 'test' AS split
FROM cte
WHERE Label = 1
  AND rn > (SELECT CEIL(COUNT(*) * @train_size) FROM cte WHERE Label = 1);
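As an alternative to the four-branch UNION (a different technique from the query above, shown only as a sketch), a second window function, COUNT(*) OVER (PARTITION BY Label), can supply each label's row count so that a single CASE expression tags every row in one pass. The sketch runs in SQLite through Python's sqlite3 module and assumes SQLite 3.25 or later. The integer percentage `pct` is a hypothetical stand-in for @train_size, and the numbering is ordered by Customer only to make the output deterministic. In MySQL 8.0 the same shape should work with CEIL(cnt * @train_size) in place of the integer ceiling.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tab (Customer TEXT, Label INTEGER)")
conn.executemany("INSERT INTO tab VALUES (?, ?)", [
    ("Sally", 1), ("Molly", 1), ("James", 0), ("Kyle", 0), ("Sara", 0),
    ("Mathew", 1), ("Kelly", 1), ("Brad", 1), ("Sam", 0), ("Alex", 0),
])

# rn numbers rows within each label; cnt is that label's total row count.
# (cnt * :pct + 99) / 100 is an integer CEIL(cnt * pct / 100).
SPLIT_SQL = """
WITH cte AS (
    SELECT Customer, Label,
           ROW_NUMBER() OVER (PARTITION BY Label ORDER BY Customer) AS rn,
           COUNT(*)     OVER (PARTITION BY Label)                   AS cnt
    FROM tab
)
SELECT Customer, Label,
       CASE WHEN rn <= (cnt * :pct + 99) / 100 THEN 'train' ELSE 'test' END AS split
FROM cte
ORDER BY split, Label, Customer
"""

rows = conn.execute(SPLIT_SQL, {"pct": 80}).fetchall()
for row in rows:
    print(row)  # the two 'test' rows come first: ('Sara', 0, 'test'), ('Sally', 1, 'test')
```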


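Note: As written, ROW_NUMBER() has no ORDER BY, so rows are numbered in whatever order MySQL happens to read them; this often matches the table's current order, but it is not guaranteed. If you need to shuffle the rows, add an ORDER BY clause inside the ROW_NUMBER window function, using the RAND function or any other ordering of your choice:

ROW_NUMBER() OVER(PARTITION BY Label ORDER BY RAND()) AS rn

With a random order, generate the numbering once (for example, store the numbered rows in a temporary table) and derive both sets from it. Running the separate training and test queries would shuffle independently, so a customer could end up in both sets or in neither.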
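Here is a sketch of a shuffled split that numbers the rows only once, run in SQLite through Python's sqlite3 module (RANDOM() is SQLite's counterpart of MySQL's RAND(); SQLite 3.25 or later is assumed). Storing the randomly numbered rows in a temporary table means the training and test queries see the same shuffle, so the two sets cannot overlap. The table name `numbered` and the integer percentage `PCT` are hypothetical, and the per-label count comes from a COUNT(*) window rather than the subqueries used above.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tab (Customer TEXT, Label INTEGER)")
conn.executemany("INSERT INTO tab VALUES (?, ?)", [
    ("Sally", 1), ("Molly", 1), ("James", 0), ("Kyle", 0), ("Sara", 0),
    ("Mathew", 1), ("Kelly", 1), ("Brad", 1), ("Sam", 0), ("Alex", 0),
])

# Number the rows once, in random order within each label, and keep the result
# so both queries below read the same shuffle.
conn.execute("""
CREATE TEMP TABLE numbered AS
SELECT Customer, Label,
       ROW_NUMBER() OVER (PARTITION BY Label ORDER BY RANDOM()) AS rn,
       COUNT(*)     OVER (PARTITION BY Label)                   AS cnt
FROM tab
""")

PCT = 80  # hypothetical integer stand-in for @train_size = 0.8
# (cnt * ? + 99) / 100 is an integer CEIL(cnt * PCT / 100).
train = conn.execute(
    "SELECT Customer, Label FROM numbered WHERE rn <= (cnt * ? + 99) / 100", (PCT,)
).fetchall()
test = conn.execute(
    "SELECT Customer, Label FROM numbered WHERE rn > (cnt * ? + 99) / 100", (PCT,)
).fetchall()

print(len(train), len(test))  # 8 2, whichever customers were drawn
```

Which customers land in each set changes from run to run; the per-label sizes do not.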

Note2: This code uses the MySQL dialect; window functions such as ROW_NUMBER() require MySQL 8.0 or later. If you're working with another DBMS, please update your question's tags with the appropriate DBMS tag.

Note3: Your sample output puts 60% of the rows in the training set, but your description asks for an 80% split. Whatever split you need, assign the fraction to the variable @train_size (e.g. SET @train_size = 0.8;), which the queries use to compute each label's threshold.
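Because each label's threshold is rounded up, the realized split can differ from @train_size on small tables. The hypothetical helper below mirrors CEIL(COUNT(*) * @train_size) using an integer percentage, and shows how many of the question's 5 rows per label land in Train. A 50% setting, equivalent to CEIL(COUNT(*)/2), already puts 3 of 5 rows per label (60% overall) in the training set, which matches the sample output.

```python
def per_label_train_count(label_count, train_pct):
    """CEIL(label_count * train_pct / 100), computed with integers only."""
    return (label_count * train_pct + 99) // 100


for pct in (50, 60, 80):
    n_train = per_label_train_count(5, pct)
    print(pct, n_train, 5 - n_train)
# 50 3 2
# 60 3 2
# 80 4 1
```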
