Reputation: 607
I've been trying to analyze the DecisionTreeRegressor I trained in sklearn. I found http://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html useful in determining the attributes that split each branch in the tree, specifically this code snippet:
n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold
# The tree structure can be traversed to compute various properties such
# as the depth of each node and whether or not it is a leaf.
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)]  # seed is the root node id and its parent depth
while len(stack) > 0:
    node_id, parent_depth = stack.pop()
    node_depth[node_id] = parent_depth + 1
    # If we have a test node
    if (children_left[node_id] != children_right[node_id]):
        stack.append((children_left[node_id], parent_depth + 1))
        stack.append((children_right[node_id], parent_depth + 1))
    else:
        is_leaves[node_id] = True
print("The binary tree structure has %s nodes and has "
      "the following tree structure:"
      % n_nodes)
for i in range(n_nodes):
    if is_leaves[i]:
        print("%snode=%s leaf node." % (node_depth[i] * "\t", i))
    else:
        print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
              "node %s."
              % (node_depth[i] * "\t",
                 i,
                 children_left[i],
                 feature[i],
                 threshold[i],
                 children_right[i],
                 ))
However, this doesn't tell me the value of each leaf node. If the above prints out something that looks like this:
The binary tree structure has 7 nodes and has the following tree structure:
node=0 test node: go to node 1 if X[:, 2] <= 1.00764083862 else to node 4.
    node=1 test node: go to node 2 if X[:, 2] <= 0.974808812141 else to node 3.
        node=2 leaf node.
        node=3 leaf node.
    node=4 test node: go to node 5 if X[:, 0] <= -2.90554761887 else to node 6.
        node=5 leaf node.
        node=6 leaf node.
How do I find the value that a leaf represents, for example node 2?
Upvotes: 1
Views: 3490
Reputation: 60319
The attribute you are looking for is estimator.tree_.value.
Let's make a reproducible example, since the one you link to from the docs is for classification and not for regression:
import numpy as np
from sklearn.tree import DecisionTreeRegressor
# dummy data
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))
estimator = DecisionTreeRegressor(max_depth=3)
estimator.fit(X, y)
After that, using your code verbatim, we get:
The binary tree structure has 15 nodes and has the following tree structure:
node=0 test node: go to node 1 if X[:, 0] <= 3.13275051117 else to node 8.
    node=1 test node: go to node 2 if X[:, 0] <= 0.513901114464 else to node 5.
        node=2 test node: go to node 3 if X[:, 0] <= 0.0460066311061 else to node 4.
            node=3 leaf node.
            node=4 leaf node.
        node=5 test node: go to node 6 if X[:, 0] <= 2.02933192253 else to node 7.
            node=6 leaf node.
            node=7 leaf node.
    node=8 test node: go to node 9 if X[:, 0] <= 3.85022854805 else to node 12.
        node=9 test node: go to node 10 if X[:, 0] <= 3.42930102348 else to node 11.
            node=10 leaf node.
            node=11 leaf node.
        node=12 test node: go to node 13 if X[:, 0] <= 4.68025827408 else to node 14.
            node=13 leaf node.
            node=14 leaf node.
Now, estimator.tree_.value contains the values for all the tree nodes (here 15):
len(estimator.tree_.value)
# 15
and to get, for example, the value for node #3, we ask
estimator.tree_.value[3]
# array([[-1.1493464]])
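To list the value of every leaf at once rather than one node at a time, something like the following sketch should work. It refits the same dummy data so that it runs on its own; the names tree, leaf_ids, leaf_values and reached are mine, introduced only for illustration. Leaves are detected the same way as in the traversal code from the question (both child pointers are equal).

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# same dummy data as in the reproducible example
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))

estimator = DecisionTreeRegressor(max_depth=3)
estimator.fit(X, y)

tree = estimator.tree_  # illustrative shorthand, not part of the original code

# A leaf has no children, so both child pointers hold the same sentinel value.
leaf_ids = np.flatnonzero(tree.children_left == tree.children_right)

# For a single-output regressor, tree.value has shape (n_nodes, 1, 1);
# indexing [node, 0, 0] gives one scalar per node.
leaf_values = tree.value[leaf_ids, 0, 0]

for node_id, val in zip(leaf_ids, leaf_values):
    print("node=%s leaf value: %s" % (node_id, val))

# Cross-check: apply() returns the leaf each sample ends up in, and the
# value stored at that leaf is what predict() returns for the sample.
reached = estimator.apply(X)
assert np.allclose(tree.value[reached, 0, 0], estimator.predict(X))
```

The final assertion is only a sanity check that the leaf value is indeed the prediction; it can be dropped once you are convinced.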
For a detailed explanation of the value contents (including non-terminal nodes), see my answers in "interpreting Graphviz output for decision tree regression" (for regression) and "What does scikit-learn DecisionTreeClassifier.tree_.value do?" (for classification).
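As a quick illustration for the regression case, the sketch below checks that each node's value, including the non-terminal ones, is the mean of the training targets that pass through that node. This assumes the default squared-error criterion and the same dummy data as above; the names reaches and node_means are illustrative only.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# same dummy data as in the reproducible example
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))

estimator = DecisionTreeRegressor(max_depth=3)
estimator.fit(X, y)

# decision_path() returns a sparse (n_samples, n_nodes) indicator matrix:
# entry (s, i) is nonzero when sample s passes through node i.
reaches = estimator.decision_path(X).toarray().astype(bool)

# Mean training target over the samples that reach each node.
node_means = np.array([
    y[reaches[:, i]].mean() for i in range(estimator.tree_.node_count)
])

# With the default criterion, this should match the stored node values.
assert np.allclose(node_means, estimator.tree_.value[:, 0, 0])

# The root sees every sample, so its value is simply the overall mean of y.
print(node_means[0], y.mean())
```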
Upvotes: 5