Capt_Bender

Reputation: 91

Sklearn train_test_split reporting error when running twice

I want to use train_test_split to create train, validation and test sets from my data. The "easy" way, according to other posts, is to run train_test_split twice, which makes sense. But when I try that, the second split raises an error. (Sklearn version: 0.23.2) Am I missing something?

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
print(X_train.shape, y_train.shape)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y)

Output: (80, 20) (80,)

The error it returns:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-74-bf895d511057> in <module>
      6 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
      7 print(X_train.shape, y_train.shape)
----> 8 X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y)

~\anaconda3\envs\trading\lib\site-packages\sklearn\model_selection\_split.py in train_test_split(*arrays, **options)
   2150                      random_state=random_state)
   2151 
-> 2152         train, test = next(cv.split(X=arrays[0], y=stratify))
   2153 
   2154     return list(chain.from_iterable((_safe_indexing(a, train),

~\anaconda3\envs\trading\lib\site-packages\sklearn\model_selection\_split.py in split(self, X, y, groups)
   1338         to an integer.
   1339         """
-> 1340         X, y, groups = indexable(X, y, groups)
   1341         for train, test in self._iter_indices(X, y, groups):
   1342             yield train, test

~\anaconda3\envs\trading\lib\site-packages\sklearn\utils\validation.py in indexable(*iterables)
    290     """
    291     result = [_make_indexable(X) for X in iterables]
--> 292     check_consistent_length(*result)
    293     return result
    294 

~\anaconda3\envs\trading\lib\site-packages\sklearn\utils\validation.py in check_consistent_length(*arrays)
    254     if len(uniques) > 1:
    255         raise ValueError("Found input variables with inconsistent numbers of"
--> 256                          " samples: %r" % [int(l) for l in lengths])
    257 
    258 

ValueError: Found input variables with inconsistent numbers of samples: [80, 100]

Upvotes: 0

Views: 1093

Answers (1)

Alex Serra Marrugat

Reputation: 2042

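The problem is the stratify argument in the second call. You pass stratify=y, but y still holds all 100 labels, while X_train and y_train contain only 80 samples after the first split. The array passed to stratify must have the same length as the arrays being split, which is why you get the "inconsistent numbers of samples: [80, 100]" error. Use stratify=y_train instead: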

X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y_train)
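
For completeness, here is a minimal end-to-end sketch of the two-step split with the fix applied. The random_state values and the final print are my additions for reproducibility and are not part of the original code. The shapes in the comment follow from make_classification's default of 100 samples and 20 features.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Toy data: make_classification defaults to 100 samples and 20 features.
# random_state is an assumption added here so the run is repeatable.
X, y = make_classification(random_state=0)

# First split: hold out 20% as the test set, stratified on the full labels.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# Second split: carve a validation set out of the remaining 80 samples.
# stratify must describe the arrays being split, so pass y_train, not y.
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.2, stratify=y_train, random_state=0
)

print(X_train.shape, X_val.shape, X_test.shape)
# (64, 20) (16, 20) (20, 20)
```

Note that the second test_size=0.2 is taken from the 80 training samples, so the overall proportions end up as 64% train, 16% validation and 20% test.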

Upvotes: 2
