sokeefe1014

Reputation: 237

Get feature and class names into decision tree using export graphviz

Good Afternoon,

I am working on a decision tree classifier and am having trouble visualizing it. I can output the decision tree, but I cannot get my feature or class names/labels into it. My data is in a pandas dataframe, which I convert to a NumPy array and pass to the classifier. I've tried a few things, but the export errors out whenever I try to specify class names. Any help would be appreciated. Code is below.

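from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# .ix was removed in pandas 1.0; .iloc does the same positional column slice
all_inputs = df.iloc[:, 14:].values
all_classes = df['wic'].values

(training_inputs,
 testing_inputs,
 training_classes,
 testing_classes) = train_test_split(all_inputs, all_classes, train_size=0.75, random_state=1)

decision_tree_classifier = DecisionTreeClassifier()
decision_tree_classifier.fit(training_inputs, training_classes)

export_graphviz(decision_tree_classifier, out_file="mytree.dot",
                feature_names=??,   # what goes here?
                class_names=??)     # and here?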

Like I said, it runs fine and outputs a decision tree visualization if I take out the feature_names and class_names parameters. I'd like to include them in the output if possible, though, and I have hit a wall...

Any help would be greatly appreciated!

Thanks,

Scott

Upvotes: 13

Views: 26121

Answers (3)

Hutch

Reputation: 173

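# for class names
# class_names must be in the sorted order of clf.classes_, while unique()
# returns labels in order of appearance, so sort them
y_class_names = sorted(df['your_class_type_column'].unique())

# for feature names
X_col_names = list(X_train.columns)
feature_names = X_col_names
# or, if the class column is the last column of df:
X_names = list(df.columns[:-1])

e.g.:

from sklearn import tree
import graphviz

# DOT data
dot_data = tree.export_graphviz(clf_gini, out_file=None,
                                feature_names=X_col_names,
                                class_names=y_class_names,
                                filled=True)

# Draw graph
graph = graphviz.Source(dot_data, format="png")
graph  # renders inline in a Jupyter notebook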
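A self-contained sketch of the same approach, assuming a small hypothetical DataFrame (`df` with columns `x1`, `x2`, `label`) in place of real data. It takes the class names from the fitted `clf_gini.classes_` rather than from `unique()`, because `export_graphviz` expects `class_names` in ascending order (the order of `classes_`), whereas `unique()` returns labels in order of appearance. It stops at the DOT string, so the `graphviz` package is not needed.

```python
import pandas as pd
from sklearn import tree

# Hypothetical toy data; 'label' plays the role of the class column
df = pd.DataFrame({
    "x1": [0, 1, 2, 3, 4, 5],
    "x2": [5, 3, 4, 1, 0, 2],
    "label": ["b", "a", "b", "a", "a", "b"],
})
X_train = df[["x1", "x2"]]
y_train = df["label"]

clf_gini = tree.DecisionTreeClassifier(criterion="gini", random_state=0)
clf_gini.fit(X_train, y_train)

X_col_names = list(X_train.columns)

# unique() gives ['b', 'a'] here (order of appearance), but export_graphviz
# indexes class_names by position in clf.classes_, which is sorted
y_class_names = [str(c) for c in clf_gini.classes_]

dot_data = tree.export_graphviz(clf_gini, out_file=None,
                                feature_names=X_col_names,
                                class_names=y_class_names,
                                filled=True)
print(dot_data[:200])
```

Passing the `unique()` order here would silently swap the `a` and `b` labels in the drawing rather than raise an error, which is why the names are taken from `classes_`.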

Upvotes: 0

maxymoo

Reputation: 36545

The class names are stored in decision_tree_classifier.classes_, the classes_ attribute of your fitted DecisionTreeClassifier instance, and the feature names are the columns of your input dataframe. In your case you will have:

class_names = decision_tree_classifier.classes_
feature_names = df.columns[14:]
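A runnable end-to-end sketch of this answer applied to the question's pipeline. The DataFrame is hypothetical: 13 placeholder columns plus `wic` occupy the first 14 positions and two made-up feature columns (`height`, `weight`) follow, so `df.iloc[:, 14:]` and `df.columns[14:]` select the same columns. `.iloc` stands in for the removed `.ix`. The class names are cast to `str` on the assumption that the labels may be integers, which `export_graphviz` can fail to format as names.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

rng = np.random.RandomState(1)
n = 40

# Hypothetical frame: 13 placeholder columns + 'wic' = 14 non-feature columns
df = pd.DataFrame({f"meta{i}": np.arange(n) for i in range(13)})
df["wic"] = rng.choice([0, 1], size=n)        # integer class labels
df["height"] = rng.normal(170, 10, size=n)    # feature columns start at 14
df["weight"] = rng.normal(70, 8, size=n)

all_inputs = df.iloc[:, 14:].values
all_classes = df["wic"].values

(training_inputs,
 testing_inputs,
 training_classes,
 testing_classes) = train_test_split(all_inputs, all_classes,
                                     train_size=0.75, random_state=1)

clf = DecisionTreeClassifier(random_state=0)
clf.fit(training_inputs, training_classes)

# Same positional slice as the inputs, so names line up with columns
feature_names = list(df.columns[14:])
# classes_ is already in the order export_graphviz expects; make them strings
class_names = [str(c) for c in clf.classes_]

# out_file=None returns the DOT source; use out_file="mytree.dot" to write a file
dot_data = export_graphviz(clf, out_file=None,
                           feature_names=feature_names,
                           class_names=class_names)
print(dot_data.splitlines()[0])
```

The key point is that both name lists are derived from the same objects the model was trained on (`df.columns[14:]` for the columns, `clf.classes_` for the labels), so they cannot drift out of alignment.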

Upvotes: 21

D Petrova

Reputation: 91

Personally, class_names=True worked for me. It shows a symbolic representation of the predicted class (y[0], y[1], ...) instead of the actual class names.

from sklearn import tree

feature_names = df.columns[14:]
tree.export_graphviz(decision_tree_classifier, out_file="mytree.dot",
                     feature_names=feature_names,
                     class_names=True)  # Python's boolean is True, not TRUE

Here are more details on the topic: https://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html
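A small sketch comparing the two options, using scikit-learn's bundled iris dataset only because it needs no download. With `class_names=True` each node's majority class is labelled `y[0]`, `y[1]`, ...; passing a list of strings in the order of `clf.classes_` puts readable names in their place.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

# Symbolic labels: "class = y[0]", "class = y[1]", ...
dot_true = export_graphviz(clf, out_file=None, class_names=True)

# Readable labels, listed in the order of clf.classes_ (0, 1, 2)
dot_named = export_graphviz(clf, out_file=None,
                            class_names=["setosa", "versicolor", "virginica"])
```

`class_names=True` is a quick fallback when the labels are awkward to turn into strings; the symbolic index still maps back to `clf.classes_[i]`.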

Upvotes: 8
