thestruggleisreal

Reputation: 1295

How to find a Class in the graphviz-graph of the Random Forest of scikit-learn?

I train a Random Forest classifier with 10 estimators. Then I save every tree graph with Graphviz as .dot and .png files. Finally, I call rfclf.predict.

From the prediction output, I picked one of the predicted classes and searched for it in the graphs by searching the .dot files as plain text (just Ctrl+F, which worked with another model). But I cannot find that class. When I look at the PNG files, I see only one class in the nodes. (I cannot show the graphs here.) That is strange, because if no node had a different class, the forest could not predict it.

My goal is to trace the path by which a data object gets its predicted class.

Here are the relevant parts of my code:

from sklearn.ensemble import RandomForestClassifier

rfclf = RandomForestClassifier(class_weight = 'balanced')
rfclf.fit(x,y)

output:

RandomForestClassifier(bootstrap=True, class_weight='balanced', criterion='gini', max_depth=None, max_features='auto', max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=None, oob_score=False, random_state=None, verbose=0, warm_start=False)

from sklearn.tree import export_graphviz

estimator = rfclf.estimators_[8]  # or [0], [1], ..., [9] because there are 10 estimators
# Export as dot-file
export_graphviz(estimator, out_file='Graphs/rfclf8.dot', 
                feature_names = x.columns,
                class_names = y,
                rounded = True, proportion = False, 
                precision = 2, filled = True)

# convert to PNG with system command (needs Graphviz)
from subprocess import call
call(['dot', '-Tpng', 'Graphs/rfclf8.dot', '-o', 'Graphs/rfclf8.png', '-Gdpi=600'])
#predict
rfclf.predict(dfP)

output: array(['-different classes-'], dtype=object)

Is there something wrong in the code? It worked well for a different dataset.
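One likely reason only a single class shows up in the nodes is class_names = y. export_graphviz labels each node with class_names[i], where i is the position of the largest entry in that node's value array, so it expects one name per class in the order of rfclf.classes_, not the full per-sample label vector. A minimal sketch under assumptions that are not in the question: the iris dataset (with string labels) stands in for x and y, and random_state=0 is added only for reproducibility.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_graphviz

# Assumption: iris with string labels stands in for the question's x and y.
iris = load_iris()
X = iris.data
y = iris.target_names[iris.target]  # first 50 rows are 'setosa'

rfclf = RandomForestClassifier(n_estimators=10, class_weight='balanced', random_state=0)
rfclf.fit(X, y)
estimator = rfclf.estimators_[8]

# With class_names=y, class position k is looked up as y[k], i.e. the k-th
# training row, so every node gets one of the first few labels of y.
labels_if_y_passed = [str(y[k]) for k in range(len(rfclf.classes_))]
print(labels_if_y_passed)  # ['setosa', 'setosa', 'setosa'] for this data

# One name per class, in the sorted order the forest uses internally.
dot = export_graphviz(estimator, out_file=None,
                      feature_names=iris.feature_names,
                      class_names=rfclf.classes_,
                      rounded=True, precision=2, filled=True)
print('class = versicolor' in dot, 'class = virginica' in dot)
```

With one name per class, searching the exported .dot text for a predicted class name should find the leaves that favour it. A single exported tree can still show no leaf for a class the forest predicts, because the forest's prediction combines all ten trees.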

Upvotes: 1

Views: 591

Answers (1)

Jon Nordby

Reputation: 6259

In order to trace the paths taken to classify a particular sample, you should use the decision_path() method of RandomForestClassifier. It has been available since scikit-learn 0.18.0.

Example code is available at https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
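A minimal sketch of that approach, assuming the iris dataset as a stand-in for the question's data and tree index 8 as in the question. RandomForestClassifier.decision_path() returns a sparse node indicator covering all trees, plus an offset array (n_nodes_ptr) marking where each tree's columns begin. The helper trace() is hypothetical, written only for this illustration.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

# Assumption: iris with string labels stands in for the question's x and y.
iris = load_iris()
X = iris.data
y = iris.target_names[iris.target]
feature_names = iris.feature_names

rfclf = RandomForestClassifier(n_estimators=10, class_weight='balanced', random_state=0)
rfclf.fit(X, y)

# One sample, kept 2-D; float32 matches the dtype the trees compare against.
sample = X[[100]].astype(np.float32)

# indicator: sparse (n_samples, total nodes of all trees); n_nodes_ptr: per-tree column offsets
indicator, n_nodes_ptr = rfclf.decision_path(sample)


def trace(tree_idx):
    """Hypothetical helper: (node_id, description) pairs along the sample's path in one tree."""
    tree = rfclf.estimators_[tree_idx].tree_
    start, end = n_nodes_ptr[tree_idx], n_nodes_ptr[tree_idx + 1]
    # Columns in this tree's block are local node ids; children always have higher
    # ids than their parent, so sorting gives root-to-leaf order.
    node_ids = np.sort(indicator[:, start:end].nonzero()[1])
    steps = []
    for node in node_ids:
        if tree.children_left[node] == tree.children_right[node]:  # leaf: both are -1
            cls = rfclf.classes_[np.argmax(tree.value[node][0])]
            steps.append((int(node), f"leaf -> class = {cls}"))
        else:
            f, thr = tree.feature[node], tree.threshold[node]
            op = "<=" if sample[0, f] <= thr else ">"
            steps.append((int(node), f"{feature_names[f]} = {sample[0, f]:.2f} {op} {thr:.2f}"))
    return steps


for node, text in trace(8):
    print(node, text)

# Each tree's own prediction, mapped from its encoded index back to the original label
per_tree = [rfclf.classes_[int(est.predict(sample)[0])] for est in rfclf.estimators_]
print(per_tree)

# The forest averages the trees' class probabilities and takes the argmax
avg_proba = np.mean([est.predict_proba(sample) for est in rfclf.estimators_], axis=0)
print(rfclf.classes_[np.argmax(avg_proba)], rfclf.predict(sample)[0])
```

The individual trees are fitted on integer-encoded labels, so their predict() returns class indices, which is why they are mapped through rfclf.classes_ above. Because the forest averages predict_proba over all ten trees, the final class depends on every tree, not only on the one that was exported to PNG.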

Upvotes: 0
