Reputation: 91
I want to use train_test_split to create train, validation and test sets from my data. The "easy" way, according to other posts, is to run train_test_split twice, which makes sense. But when I try that, the second split raises an error (sklearn version: 0.23.2). Am I missing something?
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
df = make_classification()
X = df[0]
y = df[1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
print(X_train.shape, y_train.shape)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y)
Output: (80, 20) (80,)
The error it returns:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-74-bf895d511057> in <module>
6 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
7 print(X_train.shape, y_train.shape)
----> 8 X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y)
~\anaconda3\envs\trading\lib\site-packages\sklearn\model_selection\_split.py in train_test_split(*arrays, **options)
2150 random_state=random_state)
2151
-> 2152 train, test = next(cv.split(X=arrays[0], y=stratify))
2153
2154 return list(chain.from_iterable((_safe_indexing(a, train),
~\anaconda3\envs\trading\lib\site-packages\sklearn\model_selection\_split.py in split(self, X, y, groups)
1338 to an integer.
1339 """
-> 1340 X, y, groups = indexable(X, y, groups)
1341 for train, test in self._iter_indices(X, y, groups):
1342 yield train, test
~\anaconda3\envs\trading\lib\site-packages\sklearn\utils\validation.py in indexable(*iterables)
290 """
291 result = [_make_indexable(X) for X in iterables]
--> 292 check_consistent_length(*result)
293 return result
294
~\anaconda3\envs\trading\lib\site-packages\sklearn\utils\validation.py in check_consistent_length(*arrays)
254 if len(uniques) > 1:
255 raise ValueError("Found input variables with inconsistent numbers of"
--> 256 " samples: %r" % [int(l) for l in lengths])
257
258
ValueError: Found input variables with inconsistent numbers of samples: [80, 100]
Upvotes: 0
Views: 1093
Reputation: 2042
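The problem is in the stratify argument. You are passing stratify=y, but the second call must use stratify=y_train. After the first split, X_train and y_train hold 80 samples while the original y still holds 100, so the stratify array no longer matches the data being split. That is exactly the [80, 100] mismatch reported in the error. Use this line for the second split instead: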
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y_train)
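For completeness, here is a minimal sketch of the whole two-step split with the fix applied. The helper name three_way_split, the intermediate names X_rest and y_rest, and the fixed random_state are my own choices for illustration and are not from the question. The key point is that each call stratifies on the labels of the arrays it is actually splitting.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


def three_way_split(X, y, test_size=0.2, val_size=0.2, random_state=0):
    """Split (X, y) into train, validation and test sets (hypothetical helper)."""
    # First split: hold out the test set, stratifying on the full label array.
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=random_state
    )
    # Second split: carve the validation set out of the remainder.
    # stratify must be y_rest here, because it has the same length as X_rest.
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest, test_size=val_size, stratify=y_rest,
        random_state=random_state,
    )
    return X_train, X_val, X_test, y_train, y_val, y_test


# Same toy data as in the question (100 samples, 20 features by default).
X, y = make_classification(random_state=0)
X_train, X_val, X_test, y_train, y_val, y_test = three_way_split(X, y)
print(X_train.shape, X_val.shape, X_test.shape)
```

Note that val_size is a fraction of what is left after the test split, not of the full dataset, so 0.2 here means 20% of the remaining 80 samples.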
Upvotes: 2