user3095701

Reputation: 133

Extracting the trees (predictors) from a random forest classifier

I have a specific technical question about sklearn's random forest classifier.

After fitting the data with the ".fit(X, y)" method, is there a way to extract the actual trees from the estimator object in some common format, so that the ".predict(X)" method can be implemented outside Python?

Upvotes: 13

Views: 12045

Answers (2)

ogrisel

Reputation: 40149

Yes, the trees of a forest are stored in the estimators_ attribute of the forest object.

You can have a look at the implementation of the export_graphviz function to learn how to write your own custom exporter:

https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/tree/_export.py

Here is the usage doc for this function:

http://scikit-learn.org/stable/modules/tree.html#classification
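As a rough illustration of what a custom exporter could look like, here is a minimal sketch. It assumes each entry of estimators_ exposes a tree_ object with the arrays children_left, children_right, feature, threshold and value, which is how current scikit-learn versions store fitted trees. The helper names (tree_to_dict, forest_to_json, leaf_value, forest_predict_regression) are made up for this example and are not part of scikit-learn.

```python
import json


def tree_to_dict(estimator):
    """Copy the arrays of one fitted sklearn decision tree into plain lists."""
    t = estimator.tree_
    return {
        "children_left": t.children_left.tolist(),
        "children_right": t.children_right.tolist(),
        "feature": t.feature.tolist(),
        "threshold": t.threshold.tolist(),
        # value has shape (n_nodes, n_outputs, n_classes or 1)
        "value": t.value.tolist(),
    }


def forest_to_json(forest):
    """Serialize every tree in forest.estimators_ to a JSON string."""
    return json.dumps([tree_to_dict(est) for est in forest.estimators_])


def leaf_value(tree, x):
    """Walk one exported tree for a single sample x and return the
    value row of the leaf it lands in (first output only)."""
    node = 0
    # sklearn marks leaf nodes with a child index of -1
    while tree["children_left"][node] != -1:
        if x[tree["feature"][node]] <= tree["threshold"][node]:
            node = tree["children_left"][node]
        else:
            node = tree["children_right"][node]
    return tree["value"][node][0]


def forest_predict_regression(trees, x):
    """Regression forest prediction: the mean of the per-tree leaf values."""
    return sum(leaf_value(tree, x)[0] for tree in trees) / len(trees)
```

With a fitted forest rf, forest_to_json(rf) would give a string that any language with a JSON parser can load, and the two prediction helpers show the traversal that would have to be ported. For a classifier, the leaf's value row holds per-class weights, so it would need to be normalised to probabilities and averaged over trees instead. Note also that scikit-learn converts inputs to 32-bit floats before comparing them to thresholds, so an external implementation may need to do the same to match predictions exactly on borderline values.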

Upvotes: 20

RUser4512

Reputation: 1074

Yes, there is, and @ogrisel's answer enabled me to implement the following snippet, which uses a partial random forest (only its first i trees) to predict values. It saves a lot of time if you want to cross-validate a random forest model over the number of trees, because the forest only has to be fit once:

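from sklearn.ensemble import RandomForestRegressor

rf_model = RandomForestRegressor()
rf_model.fit(x, y)

# Keep a reference to the full list of fitted trees
estimators = rf_model.estimators_

def predict(w, i):
    # Predict on the data w using only the first i trees.
    # Note: this leaves rf_model truncated to those i trees.
    rf_model.estimators_ = estimators[0:i]
    return rf_model.predict(w)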

I explained this in more detail here: extract trees from a Random Forest
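The same idea can also be sketched without modifying the fitted model. For a RandomForestRegressor, the forest's prediction is the mean of its trees' predictions, so averaging the first i trees by hand should give the same result as the truncated forest. The function name predict_with_first_trees is hypothetical, and the sketch assumes a regression forest (a classifier would average predict_proba instead).

```python
import numpy as np


def predict_with_first_trees(forest, X, i):
    """Average the predictions of the first i fitted trees of a regression forest."""
    trees = forest.estimators_[:i]
    if not trees:
        raise ValueError("i must select at least one tree")
    # Each tree.predict(X) returns one prediction per row of X
    return np.mean([tree.predict(X) for tree in trees], axis=0)
```

Calling this for increasing values of i on a held-out set gives the validation curve over the number of trees from a single fit, and rf_model itself stays untouched.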

Upvotes: 0
