Reputation: 161
Does Apache Spark provide an API to get a decision tree's prediction probability similar to scikit-learn's predict_proba function (i.e., decision_tree.predict_proba(X))?
Upvotes: 3
Views: 1266
Reputation: 1837
I was searching for this myself. I had almost hacked together a solution when I noticed that the API already has this functionality, though in a very awkward way (at least for LogisticRegressionModel):
You clear the threshold (with clearThreshold()). That way, the predict function doesn't return the label but the underlying score.
The Java docs say:
public LogisticRegressionModel clearThreshold() :: Experimental :: Clears the threshold so that predict will output raw prediction scores.
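To make the behaviour concrete, here is a minimal plain-Python sketch (no Spark required) of what clearing the threshold does for a binary logistic model. The class ThresholdedLogisticModel and its methods are hypothetical stand-ins, not Spark's API. In Spark itself you would call clearThreshold() on the trained LogisticRegressionModel and then call predict(). The sketch assumes the score is the sigmoid of the linear margin and that label 1.0 is returned when the score is strictly greater than the threshold.

```python
import math


class ThresholdedLogisticModel:
    """Hypothetical stand-in mimicking the threshold behaviour described
    above for a binary logistic regression model (this is not Spark code)."""

    def __init__(self, weights, intercept, threshold=0.5):
        self.weights = list(weights)
        self.intercept = intercept
        self.threshold = threshold  # None means the threshold has been cleared

    def clear_threshold(self):
        # After clearing, predict() returns the raw score instead of a label.
        self.threshold = None
        return self

    def set_threshold(self, threshold):
        # Re-enable labelling with a custom cut-off.
        self.threshold = threshold
        return self

    def predict(self, features):
        margin = sum(w * x for w, x in zip(self.weights, features)) + self.intercept
        score = 1.0 / (1.0 + math.exp(-margin))  # sigmoid score in (0, 1)
        if self.threshold is None:
            return score
        return 1.0 if score > self.threshold else 0.0


model = ThresholdedLogisticModel(weights=[2.0, -1.0], intercept=0.0)
print(model.predict([1.0, 1.0]))  # 1.0 (margin 1.0, score about 0.731 > 0.5)
model.clear_threshold()
print(model.predict([1.0, 1.0]))  # about 0.731, the raw score
```

The same object switches between returning labels and returning scores purely based on whether a threshold is set, which mirrors the awkward-but-workable design described in the answer.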
FYI: the returned values are between 0 and 1, and the default threshold is 0.5, so you can easily decide what you want to set your threshold to (and then set it with setThreshold()).
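Because the cleared-threshold output is a score in [0, 1], picking a threshold can be as simple as re-labelling the collected scores at a few candidate values and comparing the results. A small sketch, assuming you have already collected raw scores and true labels into plain Python lists; the helper names, the toy data, and the error-count criterion are illustrative assumptions (in practice you might prefer precision/recall or another metric).

```python
def labels_at_threshold(scores, threshold):
    """Turn raw scores in [0, 1] into 0/1 labels (strictly above -> 1.0)."""
    return [1.0 if s > threshold else 0.0 for s in scores]


def error_count(scores, true_labels, threshold):
    """Number of misclassified points if this threshold were used."""
    predicted = labels_at_threshold(scores, threshold)
    return sum(1 for p, y in zip(predicted, true_labels) if p != y)


# Hypothetical raw scores, e.g. collected from predict() after clearThreshold()
scores = [0.1, 0.4, 0.35, 0.8, 0.65]
true_labels = [0.0, 0.0, 1.0, 1.0, 1.0]

for t in (0.3, 0.5, 0.7):
    print(t, error_count(scores, true_labels, t))
```

Once a value looks good on your data, you would pass it back to the model with setThreshold() so that predict returns labels again.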
Upvotes: 1