I am trying to use KFold cross-validation for my model, but I get an error when doing so. I know that KFold only accepts 1D arrays, but even after converting the length input to an array it still gives me this problem.
import numpy as np
from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier
from sklearn.cross_validation import train_test_split
from sklearn.cross_validation import KFold

if __name__ == "__main__":
    np.random.seed(1335)
    verbose = True
    shuffle = False
    n_folds = 5
    y = np.array(y)
    if shuffle:
        idx = np.random.permutation(y.size)
        X_train = X_train[idx]
        y = y[idx]
    skf = KFold(y, n_folds)
    models = [RandomForestClassifier(n_estimators=100, n_jobs=-1, criterion='gini'),
              ExtraTreesClassifier(n_estimators=100, n_jobs=-1, criterion='entropy')]
    print("Stacking in progress")
    A = []
    for j, clf in enumerate(models):
        print(j, clf)
        for i, (itrain, itest) in enumerate(skf):
            print("Fold :", i)
            x_train = X_train[itrain]
            x_test = X_train[itest]
            y_train = y[itrain]
            y_test = y[itest]
            print(x_train.shape, x_test.shape)
            print(len(x_train), len(x_test))
            clf.fit(x_train, y_train)
            pred = clf.predict_proba(x_test)
            A.append(pred)
I get the error on the line "skf = KFold(y, n_folds)". Any help with this would be appreciated.
According to its docs, KFold() does not expect y as an input, only the number of splits (n_splits, which replaces the old n_folds argument). Once you have an instance of KFold, call myKfold.split(X), where X is all of your input data, to obtain an iterator yielding train and test indices. Example copied from the sklearn docs:
>>> import numpy as np
>>> from sklearn.model_selection import KFold
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
>>> y = np.array([1, 2, 3, 4])
>>> kf = KFold(n_splits=2)
>>> kf.get_n_splits(X)
2
>>> print(kf)
KFold(n_splits=2, random_state=None, shuffle=False)
>>> for train_index, test_index in kf.split(X):
...     print("TRAIN:", train_index, "TEST:", test_index)
...     X_train, X_test = X[train_index], X[test_index]
...     y_train, y_test = y[train_index], y[test_index]
TRAIN: [2 3] TEST: [0 1]
TRAIN: [0 1] TEST: [2 3]
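Applied to the stacking loop from the question, the same pattern might look roughly like the sketch below. It assumes scikit-learn 0.18 or newer (where KFold lives in sklearn.model_selection; the old sklearn.cross_validation module was removed in 0.20). The question does not show X_train or y, so make_classification generates hypothetical stand-in data, and stack_predictions is a hypothetical helper name used only for illustration. The manual np.random.permutation step is replaced by KFold's own shuffle option.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier
from sklearn.model_selection import KFold


def stack_predictions(X_train, y, n_folds=5, shuffle=False, seed=1335):
    """Hypothetical helper: collect per-fold predict_proba outputs per model."""
    # KFold takes the number of splits, not y. Shuffling is done by the
    # splitter itself; random_state only matters when shuffle=True.
    kf = KFold(n_splits=n_folds, shuffle=shuffle,
               random_state=seed if shuffle else None)
    models = [
        RandomForestClassifier(n_estimators=100, n_jobs=-1, criterion="gini"),
        ExtraTreesClassifier(n_estimators=100, n_jobs=-1, criterion="entropy"),
    ]
    A = []
    for j, clf in enumerate(models):
        # split(X) yields (train_indices, test_indices) as 1D index arrays
        for i, (itrain, itest) in enumerate(kf.split(X_train)):
            clf.fit(X_train[itrain], y[itrain])
            A.append(clf.predict_proba(X_train[itest]))
    return A


if __name__ == "__main__":
    # Synthetic stand-in for the question's X_train / y (not shown there).
    X_train, y = make_classification(n_samples=200, n_features=10,
                                     random_state=0)
    A = stack_predictions(X_train, y, n_folds=5, shuffle=True)
    print(len(A), A[0].shape)
```

With two models and five folds, A holds ten probability arrays. If folds stratified by class are wanted (the variable name skf in the question hints at that), StratifiedKFold is the splitter to use, and its split(X, y) method does take y.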