Reputation: 689
I'm quite new to machine learning techniques, and I'm having trouble following some of the scikit-learn documentation and other Stack Overflow posts. I'm trying to build a simple model from a set of medical data that will help me predict which of three classes a patient falls into.
I load the data via pandas, convert all the object columns to integers (e.g. Male = 0, Female = 1), and run the following code:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split  # sklearn.cross_validation was removed in newer versions
from sklearn.preprocessing import label_binarize
from sklearn.ensemble import ExtraTreesClassifier
# Load data file with all integers:
data = pd.read_csv('datafile.csv')
y = data["Target"]
features = list(data.columns[:-1])  # the last column is the target
x = data[features]
ydata = label_binarize(y, classes=[0, 1, 2])
n_classes = ydata.shape[1]
X_train, X_test, y_train, y_test = train_test_split(x, ydata, test_size=.5)
model2 = ExtraTreesClassifier()
model2.fit(X_train, y_train)
out = model2.predict(X_test)
print(np.min(out), np.max(out))
The predicted values in out range between 0.0 and 1.0, but the classes I am trying to predict are 0, 1, and 2. What am I missing?
Upvotes: 0
Views: 1346
Reputation: 33542
That's normal behaviour in scikit-learn.
There are two possible approaches:

Binarized targets (what you did): label_binarize transforms y of shape [n_samples,] into shape [n_samples, n_classes] (one dimension added; integers in range(0, n_classes) become binary indicator values). classifier.predict() will then also return results of the form [n_predict_samples, n_classes], with 0 and 1 as the only values. That's what you observe! For example, [[0 0 0 1], [1 0 0 0], [0 1 0 0]] are the predictions for classes 3, 0, 1.

Plain integer targets: keep y of shape [n_samples,]. classifier.predict() will then return results of the form [n_predict_samples,], with possibly other values than 0 and 1, e.g. [3 0 1].
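The difference can be seen directly in the shapes predict() returns. A minimal sketch on synthetic data (the random features and the three-class labels are made up for illustration; a modern scikit-learn with sklearn.model_selection is assumed):

```python
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.ensemble import ExtraTreesClassifier

rng = np.random.RandomState(0)
X = rng.rand(90, 4)                       # 90 samples, 4 features
y = rng.randint(0, 3, size=90)            # integer labels 0, 1, 2

# Approach 1: plain integer targets -> predict() returns class labels
clf = ExtraTreesClassifier(random_state=0).fit(X, y)
pred = clf.predict(X)
print(pred.shape)                         # (90,) -- one label per sample

# Approach 2: binarized targets -> predict() returns an indicator matrix
Y = label_binarize(y, classes=[0, 1, 2])  # shape (90, 3), values 0/1
clf2 = ExtraTreesClassifier(random_state=0).fit(X, Y)
pred2 = clf2.predict(X)
print(pred2.shape)                        # (90, 3) -- only 0s and 1s

# To map the indicator output back to integer labels, take the argmax:
labels = np.argmax(pred2, axis=1)
print(labels.shape)                       # (90,)
```

With approach 2, the 0.0/1.0 values you saw are expected; np.argmax over the class axis recovers the original 0/1/2 labels.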
Both output shapes are mentioned in the docs:
predict(X)
Returns:
y : array of shape = [n_samples] or [n_samples, n_outputs]
The predicted classes.
Remark: the above behaviour holds for most (if not all) scikit-learn classifiers, not only ExtraTreesClassifier.
Upvotes: 1