I've trained a gradient boosting classifier, and I would like to visualize it using the export_graphviz tool shown here.
When I try it, I get:
AttributeError: 'GradientBoostingClassifier' object has no attribute 'tree_'
This is because export_graphviz is meant for single decision trees, but I guess there's still a way to visualize it, since a gradient boosting classifier is built from underlying decision trees.
How can I do that?
There is another nice visualization package called dtreeviz, which I find really useful (the snippet below uses the dtreeviz 1.x API; 2.x releases moved to dtreeviz.model(...)).
Using the code from the existing answer:
from sklearn.ensemble import GradientBoostingClassifier
import numpy as np
from dtreeviz.trees import dtreeviz  # dtreeviz 1.x API
# Fictitious data
np.random.seed(0)
X = np.random.normal(0,1,(1000, 3))
y = X[:,0]+X[:,1]*X[:,2] > 0
# Classifier
clf = GradientBoostingClassifier(max_depth=3, random_state=0)
clf.fit(X[:600], y[:600])
# Get tree number 42 (a DecisionTreeRegressor)
sub_tree_42 = clf.estimators_[42, 0]
# Visualization
viz = dtreeviz(sub_tree_42,
               x_data=X[:600],  # the 600 training rows the tree was fit on
               y_data=y[:600],
               target_name='Positive',
               feature_names=['X0', 'X1', 'X2'],
               class_names=['Negative', 'Positive'],
               title='Tree 42 visualization')
viz.save("tree_visualization.svg")
viz.view()
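One caveat with the dtreeviz snippet: tree 42 is a regression tree that was fit to pseudo-residuals, not to the original labels y, so node distributions drawn from y_data are not what the tree actually split on. Below is a minimal sketch for recovering those residuals, assuming the default binary log-loss and subsample=1.0 (so every training row is used at every stage). stage_residuals is a hypothetical helper, not part of scikit-learn or dtreeviz.

```python
from itertools import islice

import numpy as np
from scipy.special import expit, logit
from sklearn.ensemble import GradientBoostingClassifier

# Same fictitious data and model as in the answers
np.random.seed(0)
X = np.random.normal(0, 1, (1000, 3))
y = X[:, 0] + X[:, 1] * X[:, 2] > 0
X_train, y_train = X[:600], y[:600]
clf = GradientBoostingClassifier(max_depth=3, random_state=0)
clf.fit(X_train, y_train)


def stage_residuals(model, X, y, stage):
    """Hypothetical helper: y - sigmoid(F(x)), F = raw log-odds before tree `stage`.

    Assumes binary log-loss (the default) and subsample=1.0 (the default).
    """
    if stage == 0:
        # Before any tree: log-odds of the init estimator's class prior
        raw = logit(model.init_.predict_proba(X)[:, 1])
    else:
        # staged_decision_function yields the log-odds after 1, 2, ... trees;
        # item stage - 1 is the state after trees 0 .. stage - 1
        raw = next(islice(model.staged_decision_function(X), stage - 1, None))
    return np.asarray(y, dtype=float) - expit(np.ravel(raw))


residuals_42 = stage_residuals(clf, X_train, y_train, 42)
print(residuals_42[:5])
```

The resulting residuals_42 could be passed as y_data together with x_data=X_train (dropping class_names, since the tree is a regressor); this has not been tested against dtreeviz here. As far as I know, scikit-learn then overwrites each leaf value with a Newton-step estimate for the log-loss, so the leaf values will not exactly equal the mean residual in each leaf.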
The attribute estimators_ contains the underlying decision trees: for a binary problem it is an array of shape (n_estimators, 1) holding DecisionTreeRegressor objects. The following code displays one of the trees of a trained GradientBoostingClassifier. Notice that although the ensemble as a whole is a classifier, each individual tree is a regression tree that outputs floating-point values.
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_graphviz
import numpy as np
# Fictitious data
np.random.seed(0)
X = np.random.normal(0,1,(1000, 3))
y = X[:,0]+X[:,1]*X[:,2] > 0
# Classifier
clf = GradientBoostingClassifier(max_depth=3, random_state=0)
clf.fit(X[:600], y[:600])
# Get tree number 42 (a DecisionTreeRegressor)
sub_tree_42 = clf.estimators_[42, 0]
# Visualization
# Requires the Graphviz binaries (https://www.graphviz.org/download/) and the pydotplus package (pip install pydotplus)
from pydotplus import graph_from_dot_data
from IPython.display import Image
dot_data = export_graphviz(
    sub_tree_42,
    out_file=None, filled=True, rounded=True,
    special_characters=True,
    proportion=False, impurity=False,  # set to True to show them
)
graph = graph_from_dot_data(dot_data)
png = graph.create_png()
# Save (optional)
from pathlib import Path
Path('./out.png').write_bytes(png)
# Display
Image(png)
[Figure: tree number 42]
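To make the point about floating-point outputs concrete, here is a sketch that rebuilds clf.decision_function by hand: the init estimator's log-odds plus learning_rate times each tree's prediction. It assumes the default binary log-loss and init=None; manual_decision_function is a hypothetical helper, not a scikit-learn API, and it relies on internals (init_, estimators_) whose details could change between versions.

```python
import numpy as np
from scipy.special import logit
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import DecisionTreeRegressor

# Same fictitious data and model as above
np.random.seed(0)
X = np.random.normal(0, 1, (1000, 3))
y = X[:, 0] + X[:, 1] * X[:, 2] > 0
clf = GradientBoostingClassifier(max_depth=3, random_state=0)
clf.fit(X[:600], y[:600])

# Binary problem: one row per boosting stage, one column of regression trees
print(clf.estimators_.shape)  # (100, 1)
print(isinstance(clf.estimators_[42, 0], DecisionTreeRegressor))  # True


def manual_decision_function(model, X):
    """Hypothetical helper: init log-odds plus learning_rate * each tree's output."""
    raw = logit(model.init_.predict_proba(X)[:, 1])
    for tree in model.estimators_[:, 0]:
        raw = raw + model.learning_rate * tree.predict(X)
    return raw


X_test = X[600:]
print(np.allclose(manual_decision_function(clf, X_test),
                  clf.decision_function(X_test)))  # True
```

Applying the sigmoid (scipy.special.expit) to this sum gives the positive-class column of predict_proba, which is how the float outputs of the individual trees become class probabilities.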