How do I show column names in the tree drawn by tree.plot_tree?

I'm trying to draw a decision tree with scikit-learn's tree module, but the picture shows column indexes instead of column names.

tree.plot_tree(clf_decision)

Answers (1)

MaxU - stand with Ukraine

Reputation: 210842

Use the feature_names and class_names parameters of plot_tree:

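from sklearn.datasets import load_iris
from sklearn import tree
import matplotlib.pyplot as plt

iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# feature_names labels the split conditions; class_names labels the predicted classes
tree.plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names)
plt.show()  # needed outside a notebook to open the figure window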
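If clf_decision from the question was trained on a pandas DataFrame (an assumption; the question doesn't show the training code), the names can be taken straight from the DataFrame's columns. A minimal sketch, using iris loaded as a DataFrame via load_iris(as_frame=True) as a stand-in for the asker's data; the Agg backend is only there so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line in a notebook
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn import tree

# as_frame=True returns the features as a DataFrame with named columns
# (stand-in for the asker's own data, which is not shown in the question)
iris = load_iris(as_frame=True)
X, y = iris.data, iris.target

clf_decision = tree.DecisionTreeClassifier(random_state=0).fit(X, y)

fig, ax = plt.subplots(figsize=(12, 8))
# plot_tree returns the list of text annotations it drew on the axes
annotations = tree.plot_tree(
    clf_decision,
    feature_names=list(X.columns),        # column names instead of numeric indexes
    class_names=list(iris.target_names),  # class labels instead of 0, 1, 2
    filled=True,
    ax=ax,
)
# use fig.savefig(...) or plt.show() (with an interactive backend) to view the tree
```

In scikit-learn 1.0 and later, a model fitted on a DataFrame with string column names also stores them in clf_decision.feature_names_in_, so that attribute can be passed as feature_names instead.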
