user1700890

Reputation: 7742

Leaf ordering in scikit-learn

I am constructing a decision tree in scikit-learn, and the tree is missing leaf #2. Why? Here is my example:

import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz

def leaf_ordering():
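    # Load features and labels from CSV files
    X = np.genfromtxt('X.csv', delimiter=',')
    Y = np.genfromtxt('Y.csv', delimiter=',')
    dt = DecisionTreeClassifier(min_samples_leaf=100, random_state=99)
    dt.fit(X, Y)
    # apply() returns, for each sample, the ID of the leaf it ends up in
    print(set(dt.apply(X)))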

leaf_ordering()

Links to the data files: X, Y

Here is the output: {1, 3, 4}. As you can see, there is no leaf #2.

Upvotes: 0

Views: 187

Answers (1)

Randy

Reputation: 14857

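scikit-learn assigns IDs to every node in the tree, internal (split) nodes as well as leaves, and `apply` only ever returns leaf IDs, so the leaf IDs are not contiguous. Nodes 0 and 2 in your example are both non-leaf nodes, which leaves 1, 3, and 4 as the leaves. In my example below, the export shows that 0, 1, and 4 are internal nodes and 2, 3, 5, and 6 are the leaves, so every prediction lands in one of those four.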
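You can also check which IDs are leaves without exporting the tree. The sketch below assumes a recent scikit-learn and uses synthetic data in place of the question's X.csv/Y.csv. It relies on the fitted model's low-level `tree_` attribute, where a leaf is stored with `children_left == -1`.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic data standing in for the question's CSV files (assumption).
rng = np.random.RandomState(0)
X = rng.random_sample((100, 5))
y = X.sum(axis=1) + rng.random_sample(100)

dt = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)

tree = dt.tree_
# A node is a leaf when it has no left child; scikit-learn stores -1 there.
is_leaf = tree.children_left == -1
leaf_ids = np.flatnonzero(is_leaf)
internal_ids = np.flatnonzero(~is_leaf)

print("internal nodes:", internal_ids.tolist())
print("leaf nodes:", leaf_ids.tolist())
# apply() only returns leaf IDs, so the gaps in its output are internal nodes.
print("IDs returned by apply:", sorted(set(dt.apply(X).tolist())))
```

Every leaf holds at least `min_samples_leaf` training samples, so applying the tree to its own training data reaches every leaf.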

In [35]: X = np.random.random([100, 5])

In [36]: y = X.sum(axis=1) + np.random.random(100)

In [37]: dt = DecisionTreeRegressor(max_depth=2)

In [38]: dt.fit(X, y)
Out[38]:
DecisionTreeRegressor(criterion='mse', max_depth=2, max_features=None,
           max_leaf_nodes=None, min_samples_leaf=1, min_samples_split=2,
           min_weight_fraction_leaf=0.0, presort=False, random_state=None,
           splitter='best')

In [39]: dt.apply(X)
Out[39]:
array([6, 3, 3, 3, 6, 6, 3, 6, 3, 6, 2, 3, 3, 5, 3, 5, 5, 6, 3, 3, 3, 3, 3,
       3, 3, 6, 6, 3, 3, 3, 3, 5, 3, 5, 3, 3, 3, 3, 2, 3, 3, 3, 6, 3, 3, 3,
       3, 6, 3, 5, 2, 3, 3, 6, 3, 3, 3, 3, 3, 6, 6, 3, 6, 6, 3, 5, 6, 3, 3,
       3, 3, 6, 3, 3, 2, 3, 6, 2, 6, 2, 3, 3, 6, 2, 5, 6, 3, 3, 3, 6, 5, 3,
       3, 3, 6, 6, 3, 3, 6, 5])

In [40]: export_graphviz(dt, out_file='tree.dot')

In [41]: !cat tree.dot
digraph Tree {
node [shape=box] ;
0 [label="X[2] <= 0.7003\nmse = 0.4442\nsamples = 100\nvalue = 3.0586"] ;
1 [label="X[4] <= 0.1842\nmse = 0.3332\nsamples = 65\nvalue = 2.8321"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="mse = 0.0426\nsamples = 7\nvalue = 1.9334"] ;
1 -> 2 ;
3 [label="mse = 0.2591\nsamples = 58\nvalue = 2.9406"] ;
1 -> 3 ;
4 [label="X[0] <= 0.3576\nmse = 0.3782\nsamples = 35\nvalue = 3.4791"] ;
0 -> 4 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
5 [label="mse = 0.1212\nsamples = 10\nvalue = 2.9395"] ;
4 -> 5 ;
6 [label="mse = 0.3179\nsamples = 25\nvalue = 3.695"] ;
4 -> 6 ;
}
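The IDs in this export follow a pre-order walk of the tree: a node, then its whole left subtree, then its whole right subtree. That is why a leaf can be followed by an ID that belongs to an internal node. The sketch below assumes the default depth-first tree builder, which scikit-learn uses when `max_leaf_nodes` is not set, and uses synthetic data. It walks `tree_.children_left` and `tree_.children_right` in pre-order. The `preorder` helper is hypothetical and written only for this illustration.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic data, as in the example above (seeded here for reproducibility).
rng = np.random.RandomState(0)
X = rng.random_sample((100, 5))
y = X.sum(axis=1) + rng.random_sample(100)
dt = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)

def preorder(tree, node=0):
    """Hypothetical helper: yield the node, then its left subtree, then its right subtree."""
    yield int(node)
    left, right = tree.children_left[node], tree.children_right[node]
    if left != -1:  # -1 marks a leaf, which has no children
        yield from preorder(tree, left)
        yield from preorder(tree, right)

order = list(preorder(dt.tree_))
print("pre-order walk:", order)
leaves = [n for n in order if dt.tree_.children_left[n] == -1]
print("leaves:", leaves)
```

If the walk prints the IDs in increasing order, the numbering is pre-order, and the leaf IDs skip every position taken by an internal node.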

Upvotes: 1
