Reputation: 133
I have a specific technical question about scikit-learn's random forest classifier.
After fitting the model with the ".fit(X, y)" method, is there a way to extract the actual trees from the estimator object in some common format, so that the ".predict(X)" method can be implemented outside Python?
Upvotes: 13
Views: 12045
Reputation: 40149
Yes, the trees of a forest are stored in the estimators_ attribute of the forest object.
You can have a look at the implementation of the export_graphviz function to learn how to write your own custom exporter:
https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/tree/_export.py
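As a rough illustration of what a custom exporter could look like, here is a minimal sketch that reads each fitted tree's arrays from its tree_ attribute, dumps them to JSON, and then re-implements prediction using only those exported arrays. The helper names (tree_to_dict, predict_proba_one) and the toy dataset are assumptions made for this example, not part of scikit-learn. It also assumes a single-output classifier and input without missing values.

```python
import json

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Toy data, used only to have something to fit.
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
forest = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)


def tree_to_dict(estimator):
    """Copy the arrays that define one fitted tree into plain Python lists."""
    t = estimator.tree_
    return {
        "children_left": t.children_left.tolist(),    # -1 marks a leaf
        "children_right": t.children_right.tolist(),
        "feature": t.feature.tolist(),                # split feature per node
        "threshold": t.threshold.tolist(),            # split threshold per node
        "value": t.value[:, 0, :].tolist(),           # class weights per node
    }


exported = [tree_to_dict(est) for est in forest.estimators_]
payload = json.dumps(exported)  # portable text that another language can parse


def predict_proba_one(trees, row):
    """Average the leaf class distributions over all exported trees."""
    total = np.zeros(len(trees[0]["value"][0]))
    for tree in trees:
        node = 0
        while tree["children_left"][node] != -1:
            # scikit-learn casts inputs to float32 before comparing.
            x = float(np.float32(row[tree["feature"][node]]))
            if x <= tree["threshold"][node]:
                node = tree["children_left"][node]
            else:
                node = tree["children_right"][node]
        leaf = np.asarray(tree["value"][node], dtype=float)
        total += leaf / leaf.sum()  # normalize: works for counts or fractions
    return total / len(trees)


trees = json.loads(payload)
manual = np.array([predict_proba_one(trees, row) for row in X])
labels = forest.classes_[manual.argmax(axis=1)]
```

The traversal loop is the only part that needs to be ported to another language; the JSON payload carries everything else. The class label is recovered by indexing forest.classes_ with the position of the largest averaged probability.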
Here is the usage doc for this function:
http://scikit-learn.org/stable/modules/tree.html#classification
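For reference, a short sketch of calling the built-in exporters on a single tree taken from a forest. It uses the bundled iris dataset purely as an example; export_text is a second exporter in sklearn.tree that produces a plain-text rule listing, which may be easier to parse than DOT.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_graphviz, export_text

iris = load_iris()
forest = RandomForestClassifier(n_estimators=5, max_depth=3, random_state=0)
forest.fit(iris.data, iris.target)

# Each element of estimators_ is an ordinary fitted decision tree.
first_tree = forest.estimators_[0]

# With out_file=None the DOT source is returned as a string instead of written to disk.
dot_source = export_graphviz(
    first_tree, out_file=None, feature_names=iris.feature_names
)

# Plain-text rules, one line per split or leaf.
rules = export_text(first_tree, feature_names=list(iris.feature_names))
print(rules)
```

Looping over forest.estimators_ and exporting each tree in turn gives a full description of the forest.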
Upvotes: 20
Reputation: 1074
Yes, there is. @ogrisel's answer enabled me to write the following snippet, which predicts using only the first i trees of a fitted random forest. This saves a lot of time if you want to cross-validate a random forest over the number of trees, because the forest only has to be fit once:
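from sklearn.ensemble import RandomForestRegressor

rf_model = RandomForestRegressor()
rf_model.fit(x, y)  # x, y: training features and targets
estimators = rf_model.estimators_  # full list of fitted trees

def predict(x_new, i):
    # Predict with only the first i trees. Note: this overwrites rf_model.estimators_.
    rf_model.estimators_ = estimators[:i]
    return rf_model.predict(x_new)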
I explained this in more detail here: extract trees from a Random Forest
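A variant of the same idea that leaves the fitted model untouched is to collect each tree's predictions once and take a running mean over the trees. The sketch below is one way to do that; the synthetic dataset, the train/test split, and the variable names are assumptions chosen for the example.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

X, y = make_regression(n_samples=300, n_features=5, noise=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

rf = RandomForestRegressor(n_estimators=20, random_state=0).fit(X_train, y_train)

# One row of predictions per tree: shape (n_trees, n_test_samples).
per_tree = np.stack([est.predict(X_test) for est in rf.estimators_])

# staged[i - 1] is the forest prediction using only the first i trees.
counts = np.arange(1, len(rf.estimators_) + 1)[:, None]
staged = np.cumsum(per_tree, axis=0) / counts

# Held-out mean squared error as a function of the number of trees.
mse = ((staged - y_test) ** 2).mean(axis=1)
best_n_trees = int(mse.argmin()) + 1
```

Because a regression forest averages its trees, the last row of staged matches rf.predict(X_test), and every other row comes at no extra fitting cost.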
Upvotes: 0