Given a machine learning model built on top of scikit-learn, how can I classify new instances and then choose only those with the highest confidence? How do we define confidence in machine learning, and how do we generate it (if scikit-learn does not generate it automatically)? What should I change in this approach if I had more than 2 potential classes?
This is what I have done so far:
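# load libraries
from sklearn import neighbors
# initialize a k-nearest-neighbours classifier that looks at the 3 closest samples
knn = neighbors.KNeighborsClassifier(n_neighbors=3)
# train the model: six one-dimensional samples, labelled 0 or 1
knn.fit([[1], [2], [3], [4], [5], [6]], [0, 0, 0, 1, 1, 1])
# predict ::: get class probabilities
# (input must be 2-D, one row per sample; each output row is [P(class 0), P(class 1)])
print(knn.predict_proba([[1.5]]))
print(knn.predict_proba([[37]]))
print(knn.predict_proba([[3.5]]))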
Example:
Let's assume that we have created a model using the XYZ machine learning algorithm. Let's also assume that we are trying to classify users by gender using information such as location, hobbies, and income. Then we have, say, 10 new instances that we want to classify. As usual, applying the model gives us 10 outputs, either M (for male) or F (for female). So far so good. However, I would like to measure the precision, or confidence, of each of these results and then, using a hard-coded threshold, leave out those with low confidence. My question is how to measure it. Is the probability (as given by the predict_proba() function) a good measure? For example, can I say "keep" if the probability is between 0.9 and 1, and "omit" otherwise? Or should I use a more sophisticated method? As you can see, I lack theoretical background, so any help would be highly appreciated.
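As a minimal sketch of the keep/omit idea, reusing the toy k-NN model above: the confidence of each prediction is taken to be the largest value in its predict_proba() row. The 0.9 cut-off and the new points are illustrative assumptions. Taking the maximum over the row works unchanged with more than 2 classes, since predict_proba() returns one column per class.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Same toy data as the question: six 1-D samples, two classes
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit([[1], [2], [3], [4], [5], [6]], [0, 0, 0, 1, 1, 1])

X_new = np.array([[1.5], [3.5], [4.2], [5.5]])  # illustrative new instances
THRESHOLD = 0.9  # hard-coded cut-off, as suggested in the question

proba = knn.predict_proba(X_new)             # one column per class
confidence = proba.max(axis=1)               # probability of the predicted class
labels = knn.classes_[proba.argmax(axis=1)]  # predicted class for each sample
keep = confidence >= THRESHOLD               # True = keep, False = omit

for x, label, conf, kept in zip(X_new[:, 0], labels, confidence, keep):
    print(f"x={x:4.1f} -> class {label}, confidence {conf:.2f}, {'keep' if kept else 'omit'}")

kept_labels = labels[keep]  # only the confident predictions
```

With k = 3 the probabilities can only be 0, 1/3, 2/3 or 1, so a 0.9 threshold here simply keeps the samples whose 3 neighbours all agree.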
While this is more of a statistics question, I can give answers relating to scikit-learn.
Confidence in machine learning depends on the method used for the model. For example, with 3-NN (what you used), the "class 1" column of predict_proba(x) is n/3, where n is the number of "class 1" samples among the 3 nearest neighbours of x. If n/3 is smaller than 0.5, there are fewer than 2 "class 1" samples among the nearest neighbours, and therefore at least 2 "class 0" samples. That means x is more likely to be from "class 0". (I assume you knew that already.)
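To check the n/3 description, the sketch below recomputes it from the neighbours returned by kneighbors() and compares it with predict_proba(). It assumes the question's toy data and the default uniform weights; with weights="distance" the votes would be weighted by inverse distance instead.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

X_train = np.array([[1], [2], [3], [4], [5], [6]])
y_train = np.array([0, 0, 0, 1, 1, 1])
knn = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train)

X_new = np.array([[1.5], [4.2], [37]])  # illustrative query points

# Indices of the 3 nearest training samples for each new sample
_, neighbour_idx = knn.kneighbors(X_new)
# n = number of "class 1" labels among those 3 neighbours
n_class1 = (y_train[neighbour_idx] == 1).sum(axis=1)
manual = n_class1 / 3

print(manual)                          # n/3 computed by hand
print(knn.predict_proba(X_new)[:, 1])  # "class 1" column from scikit-learn
```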
For another method, such as an SVM, the confidence can be the distance from the point considered to the separating hyperplane; for ensemble models, it could be the proportion of aggregated votes for a given class. Scikit-learn's predict_proba() uses whatever the model makes available.
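A minimal sketch of the SVM case, assuming a linear-kernel SVC trained on the question's toy data (the kernel and the query points are my assumptions). decision_function() returns the signed value w·x + b, and dividing its magnitude by the norm of coef_ gives the geometric distance to the hyperplane. Note that SVC only provides predict_proba() when it is constructed with probability=True, in which case the probabilities come from Platt scaling fitted with internal cross-validation.

```python
import numpy as np
from sklearn.svm import SVC

X = np.array([[1], [2], [3], [4], [5], [6]], dtype=float)
y = np.array([0, 0, 0, 1, 1, 1])
svm = SVC(kernel="linear").fit(X, y)

X_new = np.array([[1.5], [2.5], [5.5]])  # illustrative query points

# Signed score w.x + b: positive means svm.classes_[1], negative means svm.classes_[0]
score = svm.decision_function(X_new)
# Geometric distance to the hyperplane: |w.x + b| / ||w||
distance = np.abs(score) / np.linalg.norm(svm.coef_)

print(svm.predict(X_new))
print(score)
print(distance)  # larger distance = further from the boundary = more confident
```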
For multiclass problems (imagine Y can be equal to A, B or C), you have two main approaches, which scikit-learn sometimes applies directly.
The first approach is OneVsOne. It trains one binary model per pair of classes (A vs B, A vs C and B vs C), runs every new sample through each of them, and takes the most probable class. If A wins against both B and C, it is very likely that the right class is A. The annoying cases are resolved by looking at the confidence of each match-up: e.g. if A wins against B, B wins against C and C wins against A, and the confidence of A winning against B is higher than the rest, it will most likely be A.
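A minimal sketch of OneVsOne voting, assuming made-up 1-D data with three well-separated classes A, B and C and LogisticRegression as the binary model (both are assumptions): one model is trained per pair, each casts a vote, and the class with the most votes wins. scikit-learn's OneVsOneClassifier is fitted alongside for comparison; it additionally adds the summed confidences to the votes to break ties, which do not occur here.

```python
from itertools import combinations

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsOneClassifier

# Invented data: three clusters on a line
X = np.array([[1], [2], [3], [11], [12], [13], [21], [22], [23]], dtype=float)
y = np.array(["A"] * 3 + ["B"] * 3 + ["C"] * 3)
X_new = np.array([[2], [12], [22]], dtype=float)

classes = np.unique(y)
votes = np.zeros((len(X_new), len(classes)), dtype=int)
for i, j in combinations(range(len(classes)), 2):
    # Train a binary model on the samples of the two classes only
    mask = (y == classes[i]) | (y == classes[j])
    clf = LogisticRegression().fit(X[mask], y[mask])
    winner = clf.predict(X_new)
    votes[:, i] += winner == classes[i]
    votes[:, j] += winner == classes[j]
manual_pred = classes[votes.argmax(axis=1)]

ovo = OneVsOneClassifier(LogisticRegression()).fit(X, y)
sk_pred = ovo.predict(X_new)

print(votes)  # one row per sample, one column per class
print(manual_pred, sk_pred)
```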
The second approach is OneVsAll (called one-vs-rest in scikit-learn), in which you train one model per class: A vs B and C, B vs A and C, C vs A and B. You then take the class that is the most likely by comparing the confidence scores.
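The OneVsAll version can be sketched the same way, again with invented data and LogisticRegression as an assumed binary model: one classifier per class, then the argmax of their scores. The data here are three 2-D clusters, because in 1-D a middle class cannot be separated from the other two by a single linear cut. OneVsRestClassifier is shown for comparison; its predict_proba() normalises the per-class scores so that each row sums to 1.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# Invented data: three 2-D clusters
X = np.array([[0, 0], [1, 0], [0, 1],
              [10, 0], [11, 0], [10, 1],
              [0, 10], [1, 10], [0, 11]], dtype=float)
y = np.array(["A"] * 3 + ["B"] * 3 + ["C"] * 3)
X_new = np.array([[0.5, 0.5], [10.5, 0.5], [0.5, 10.5]])

classes = np.unique(y)
scores = np.zeros((len(X_new), len(classes)))
for k, c in enumerate(classes):
    # Binary problem: class c (label 1) against all other classes (label 0)
    clf = LogisticRegression().fit(X, (y == c).astype(int))
    scores[:, k] = clf.predict_proba(X_new)[:, 1]
manual_pred = classes[scores.argmax(axis=1)]

ovr = OneVsRestClassifier(LogisticRegression()).fit(X, y)
print(manual_pred, ovr.predict(X_new))
print(ovr.predict_proba(X_new))  # per-class scores, normalised to sum to 1
```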
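For most classifiers, scikit-learn's predict() returns the most likely class according to the same scores that predict_proba() gives. (One exception is SVC with probability=True, whose Platt-scaled probabilities can occasionally disagree with predict().)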
I suggest you read this http://scikit-learn.org/stable/modules/multiclass.html very carefully.
EDIT:
Ah, I see what you are trying to do. predict_proba() has a big flaw: assume you have a big outlier among your new instances (e.g. a female with video games and guns as hobbies, working as a software developer). If you use, for instance, k-NN and your outlier lies inside the other class's cloud of points, predict_proba() could give 1 as the confidence score for Male while the instance is Female. However, it works well for undecided cases (e.g. someone, male or female, with video games and guns as hobbies who works in a nursery), as predict_proba() will give something around 0.5.
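The question's own code already shows a related version of this problem. A minimal sketch, reusing that toy model: x = 37 lies far from every training point, yet predict_proba() is exactly as confident as for a point sitting inside the "class 1" cluster. The mean distance to the neighbours, returned by kneighbors(), exposes the difference; using it as an extra filter is my assumption, not something predict_proba() does.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors=3)
knn.fit([[1], [2], [3], [4], [5], [6]], [0, 0, 0, 1, 1, 1])

X_new = np.array([[5.5], [37.0]])  # a typical point and an outlier
proba = knn.predict_proba(X_new)
dist, _ = knn.kneighbors(X_new)    # distances to the 3 nearest training samples

for x, p, d in zip(X_new[:, 0], proba, dist):
    print(f"x={x}: P(class 1)={p[1]:.2f}, mean distance to neighbours={d.mean():.1f}")
```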
I don't know if something better can be used, though. If you have enough training samples to do cross-validation, I suggest looking at ROC and PR (precision-recall) curves to optimize your threshold.
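A minimal sketch of that threshold search, assuming synthetic data from make_classification as a stand-in for real labelled users, k = 15 neighbours (for finer-grained probabilities than k = 3) and a hypothetical target precision of 0.9. cross_val_predict(..., method="predict_proba") gives out-of-fold probabilities, so the curves are not computed on samples the model was trained on.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import precision_recall_curve, roc_auc_score, roc_curve
from sklearn.model_selection import cross_val_predict
from sklearn.neighbors import KNeighborsClassifier

# Synthetic stand-in for real labelled data (assumption)
X, y = make_classification(n_samples=500, n_features=5, random_state=0)
knn = KNeighborsClassifier(n_neighbors=15)

# Out-of-fold P(class 1) for every sample (5-fold cross-validation)
scores = cross_val_predict(knn, X, y, cv=5, method="predict_proba")[:, 1]

fpr, tpr, roc_thresholds = roc_curve(y, scores)
precision, recall, pr_thresholds = precision_recall_curve(y, scores)
print(f"ROC AUC: {roc_auc_score(y, scores):.3f}")

TARGET_PRECISION = 0.9  # hypothetical requirement
# precision has one more entry than pr_thresholds; drop the last one to align them
ok = np.flatnonzero(precision[:-1] >= TARGET_PRECISION)
# thresholds are increasing, so the first hit keeps the most samples (highest recall)
best_threshold = pr_thresholds[ok[0]] if ok.size else None

if best_threshold is None:
    print("no threshold reaches the target precision")
else:
    i = ok[0]
    print(f"threshold {best_threshold:.3f}: precision {precision[i]:.3f}, recall {recall[i]:.3f}")
```

This tunes a threshold on P(class 1) only; for the other class, the same search can be repeated on its own column.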