Corentin Moreau

Reputation: 111

How can I get one column in my Python (pandas) DataFrame that shows all the rules of my decision tree that led to my result?

I'm working on a decision tree classifier in sklearn, and it works well: I can visualize the tree and predict my class. But I'd like to create a column in my pandas DataFrame containing the path through the tree that led to the result, i.e. a concatenation of all the rules applied, like: White=False, Black=False, Weight=1, price=5. Do you have any ideas, please?

Upvotes: 0

Views: 672

Answers (1)

Maximilian Peters

Reputation: 31739

Based on the example here, you can build an explanation of the rules applied to each sample.

  • estimator.decision_path gives you the nodes that are followed to reach the result (a short sketch of its return value follows this list)
  • is_leaves is an array that stores, for each node, whether it is a leaf, i.e. terminal (True), or a branch/decision node (False)
  • You can then iterate over node_indicator to get the nodes that were visited
  • For each node you can get the threshold and the relevant feature
  • Finally, apply the function to your dataframe and you are done.
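
As a minimal sketch of what decision_path returns (the estimator setup here simply mirrors the full code further down): it is a sparse CSR indicator matrix with one row per sample, and slicing its indices array with indptr gives the ordered ids of the nodes that the sample visits.

    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier

    # assumed setup: the same kind of tree as in the full code below
    iris = load_iris()
    estimator = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0).fit(iris.data, iris.target)

    sample = [6.6, 3.0, 4.4, 1.4]
    node_indicator = estimator.decision_path([sample])  # sparse CSR matrix, shape (1, n_nodes)
    print(node_indicator.toarray())                     # 1s mark the nodes visited by this sample

    # indptr/indices slicing turns the sparse row into the ordered list of visited node ids
    node_index = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]
    print(node_index)

The helper below wraps exactly this logic and turns each visited decision node into a readable rule: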

    import numpy as np

    def get_decision_path(estimator, feature_names, sample, precision=2, is_leaves=None):
        # build a '; '-joined string of the decision rules applied to `sample`
        if is_leaves is None:
            is_leaves = get_leaves(estimator)
        feature = estimator.tree_.feature
        threshold = estimator.tree_.threshold
    
        text = []
    
        node_indicator = estimator.decision_path([sample])
        node_index = node_indicator.indices[node_indicator.indptr[0]:
                                            node_indicator.indptr[1]]
    
        for node_id in node_index:
            if is_leaves[node_id]:
                break
    
            if sample[feature[node_id]] <= threshold[node_id]:
                threshold_sign = "<="
            else:
                threshold_sign = ">"
    
            text.append('{}: {} {} {}'.format(feature_names[feature[node_id]],
                                              sample[feature[node_id]],
                                              threshold_sign,
                                              round(threshold[node_id], precision)))
    
        return '; '.join(text)
    
    # flag each node of the fitted tree as leaf (True) or decision node (False)
    def get_leaves(estimator):
        n_nodes = estimator.tree_.node_count
        children_left = estimator.tree_.children_left
        children_right = estimator.tree_.children_right
        is_leaves = np.zeros(shape=n_nodes, dtype=bool)
        stack = [(0, -1)]
        while len(stack) > 0:
            node_id, parent_depth = stack.pop()
    
            if children_left[node_id] != children_right[node_id]:
                stack.append((children_left[node_id], parent_depth + 1))
                stack.append((children_right[node_id], parent_depth + 1))
            else:
                is_leaves[node_id] = True
        return is_leaves
    

Example

print(get_decision_path(estimator, 
                        ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'], 
                        [6.6, 3.0 , 4.4, 1.4]))

'petal width (cm): 1.4 > 0.8; petal length (cm): 4.4 <= 4.95; petal width (cm): 1.4 <= 1.65'
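
If you just want to see all rules of the whole tree at once (rather than the path for one sample), sklearn's built-in export_text can serve as a cross-check for the output above; a small optional snippet, assuming the same fitted estimator and the iris feature names from the full code below:

    from sklearn.tree import export_text

    # prints the full rule set of the fitted tree, indented by depth
    print(export_text(estimator, feature_names=iris.feature_names))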

Full code

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
from sklearn import tree
import pydotplus
from IPython.display import HTML, display

# return the decision rules applied to `sample` as a single '; '-joined string
def get_decision_path(estimator, feature_names, sample, precision=2, is_leaves=None):
    if is_leaves is None:
        is_leaves = get_leaves(estimator)
    feature = estimator.tree_.feature
    threshold = estimator.tree_.threshold

    text = []

    node_indicator = estimator.decision_path([sample])
    node_index = node_indicator.indices[node_indicator.indptr[0]:
                                        node_indicator.indptr[1]]

    for node_id in node_index:
        if is_leaves[node_id]:
            break

        if sample[feature[node_id]] <= threshold[node_id]:
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        text.append('{}: {} {} {}'.format(feature_names[feature[node_id]],
                                          sample[feature[node_id]],
                                          threshold_sign,
                                          round(threshold[node_id], precision)))

    return '; '.join(text)


# flag each node of the fitted tree as leaf (True) or decision node (False)
def get_leaves(estimator):
    n_nodes = estimator.tree_.node_count
    children_left = estimator.tree_.children_left
    children_right = estimator.tree_.children_right
    is_leaves = np.zeros(shape=n_nodes, dtype=bool)
    stack = [(0, -1)]
    while len(stack) > 0:
        node_id, parent_depth = stack.pop()

        if children_left[node_id] != children_right[node_id]:
            stack.append((children_left[node_id], parent_depth + 1))
            stack.append((children_right[node_id], parent_depth + 1))
        else:
            is_leaves[node_id] = True
    return is_leaves

# prepare data
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target

X = df.iloc[:, 0:4].to_numpy()
y = df.iloc[:, 4].to_numpy()
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# create decision tree
estimator = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0)
estimator.fit(X_train, y_train)

# visualize decision tree
dot_data = tree.export_graphviz(estimator, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
svg = graph.create_svg()
display(HTML(svg.decode('utf-8')))

# add explanation to data frame
is_leaves = get_leaves(estimator)
# pass the feature values as a plain array so integer indexing inside the helper stays positional
df['explanation'] = df.apply(lambda row: get_decision_path(estimator, df.columns[0:4], row[0:4].to_numpy(), is_leaves=is_leaves), axis=1)

df.sample(5, axis=0, random_state=42)
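
The same helper also explains a prediction for a sample that is not in the dataframe; a small usage sketch (X_test comes from the train/test split above):

    # explain the prediction for the first held-out test sample
    new_sample = X_test[0]
    print(estimator.predict([new_sample]))
    print(get_decision_path(estimator, iris.feature_names, new_sample, is_leaves=is_leaves))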

(Screenshots: the rendered decision tree and a random sample of the dataframe with the new explanation column.)

Upvotes: 1
