Carlos Pinzón

Reputation: 1427

How to visualize an sklearn GradientBoostingClassifier?

I've trained a gradient boosting classifier, and I would like to visualize it using scikit-learn's export_graphviz tool.

When I try it I get:

AttributeError: 'GradientBoostingClassifier' object has no attribute 'tree_'

This is because export_graphviz is meant for single decision trees. Still, I guess there must be a way to visualize it, since a gradient boosting classifier is built from underlying decision trees.

How can I do that?

Upvotes: 10

Views: 18857

Answers (2)

Dudelstein

Reputation: 674

There is another nice visualization package, dtreeviz, which I find really useful.

Using code from the existing answer:

from sklearn.ensemble import GradientBoostingClassifier
import numpy as np
from dtreeviz.trees import dtreeviz  # legacy (1.x-style) API; dtreeviz 2.x introduced dtreeviz.model(...)

# Fictitious data
np.random.seed(0)
X = np.random.normal(0,1,(1000, 3))
y = X[:,0]+X[:,1]*X[:,2] > 0

# Classifier
clf = GradientBoostingClassifier(max_depth=3, random_state=0)
clf.fit(X[:600], y[:600])

# Get tree number 42 (estimators_ has shape (n_estimators, 1) for binary classification)
sub_tree_42 = clf.estimators_[42, 0]

# Visualization
viz = dtreeviz(sub_tree_42,
               x_data=X,
               y_data=y,
               target_name='Positive',
               feature_names=['X0', 'X1', 'X2'],
               class_names=['Negative', 'Positive'],
               title='Tree 42 visualization')

viz.save("tree_visualization.svg") 
viz.view()

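If installing dtreeviz or Graphviz is not an option, scikit-learn's own sklearn.tree.plot_tree (matplotlib-based, available since scikit-learn 0.21) can also draw a single boosting stage. A minimal sketch, assuming matplotlib is installed; the output file name tree_42_plot_tree.png is only an example:

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs without a display
import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import plot_tree

# Same fictitious data and classifier as above
np.random.seed(0)
X = np.random.normal(0, 1, (1000, 3))
y = X[:, 0] + X[:, 1] * X[:, 2] > 0

clf = GradientBoostingClassifier(max_depth=3, random_state=0)
clf.fit(X[:600], y[:600])

# Tree number 42 is a DecisionTreeRegressor, so no class_names are passed
sub_tree_42 = clf.estimators_[42, 0]

fig, ax = plt.subplots(figsize=(14, 6))
artists = plot_tree(
    sub_tree_42,
    feature_names=["X0", "X1", "X2"],
    filled=True,
    rounded=True,
    impurity=False,  # hide the impurity value shown in each node
    ax=ax,
)
ax.set_title("Tree 42 visualization (plot_tree)")

# Example output file name (an assumption; change as needed)
out_path = Path("tree_42_plot_tree.png")
fig.savefig(out_path, bbox_inches="tight")
plt.close(fig)
```

This route needs no system-level Graphviz install; the trade-off is that matplotlib's automatic layout can get cramped for deeper trees, which a larger figsize helps with.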

Upvotes: 2

Carlos Pinzón

Reputation: 1427

The attribute estimators_ contains the underlying decision trees. The following code displays one of the trees of a trained GradientBoostingClassifier. Notice that although the ensemble as a whole is a classifier, each individual tree is a DecisionTreeRegressor fit to the pseudo-residuals (negative gradients) of the loss, so it computes floating point values rather than class labels.

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_graphviz
import numpy as np

# Fictitious data
np.random.seed(0)
X = np.random.normal(0,1,(1000, 3))
y = X[:,0]+X[:,1]*X[:,2] > 0

# Classifier
clf = GradientBoostingClassifier(max_depth=3, random_state=0)
clf.fit(X[:600], y[:600])

# Get tree number 42 (estimators_ has shape (n_estimators, 1) for binary classification)
sub_tree_42 = clf.estimators_[42, 0]

# Visualization
# Install graphviz: https://www.graphviz.org/download/
from pydotplus import graph_from_dot_data
from IPython.display import Image
dot_data = export_graphviz(
    sub_tree_42,
    out_file=None, filled=True, rounded=True,
    special_characters=True,
    proportion=False, impurity=False, # enable them if you want
)
graph = graph_from_dot_data(dot_data)
png = graph.create_png()
# Save (optional)
from pathlib import Path
Path('./out.png').write_bytes(png)
# Display
Image(png)

Tree number 42:

Code output (decision tree image)
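To make the "floating point values" concrete, the sketch below (same data and model as above) checks that the output of tree 42, scaled by learning_rate, matches what boosting stage 42 adds to the raw score yielded by staged_decision_function. It only uses public scikit-learn attributes; the variable names are illustrative.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier

# Same fictitious data and classifier as above
np.random.seed(0)
X = np.random.normal(0, 1, (1000, 3))
y = X[:, 0] + X[:, 1] * X[:, 2] > 0

clf = GradientBoostingClassifier(max_depth=3, random_state=0)
clf.fit(X[:600], y[:600])

# Binary problem: one regression tree per boosting stage
print(clf.estimators_.shape)  # (100, 1) with the default n_estimators=100
sub_tree_42 = clf.estimators_[42, 0]

# The tree outputs real numbers, not class labels
X_test = X[600:]
tree_output = sub_tree_42.predict(X_test)

# staged_decision_function yields the raw score after each stage;
# stage i includes trees 0..i, so stage 42 minus stage 41 isolates tree 42
stages = [np.ravel(s) for s in clf.staged_decision_function(X_test)]
contribution_42 = stages[42] - stages[41]

print(np.allclose(contribution_42, clf.learning_rate * tree_output))  # True
```

With more than two classes, estimators_ has one column per class, so clf.estimators_[42, k] would be the stage-42 tree for class k.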

Upvotes: 23
