user9580899

Reputation: 187

How does sklearn DecisionTreeClassifier choose output values when max_depth is set to 1?

This is my code:

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

dataset = load_iris()
X_train,X_test,y_train,y_test = train_test_split(dataset.data,dataset.target,test_size=0.3)


reg = DecisionTreeClassifier(max_depth=1)
reg.fit(X_train,y_train)
print(reg.predict(X_test))

[Image: plot of the decision tree fitted on the training set]

I have added an image of the tree fitted on the training set. On the False branch, the node has value = [0, 39, 38], which are the sample counts for classes 0, 1 and 2 respectively. So on the False branch, class 1 is the most likely output. According to this tree, the classifier should only ever predict 0 or 1, but I can also see 2 in the predictions. So how does sklearn choose the class on the False branch, and under what condition does it predict each output?

Upvotes: 1

Views: 148

Answers (1)

Venkatachalam

Reputation: 16966

I am fairly sure the difference comes from not setting random_state.

There are two sources of randomness here:

  • Train/test splitting
  • Building the decision tree model

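You might have predicted with one decision tree and then created the visualization from another one, fitted on a different random split. A depth-1 tree always predicts the majority class of the leaf a sample falls into, so a leaf with counts [0, 39, 38] predicts 1, but on another split the same leaf can contain more samples of class 2 than of class 1 and predict 2 instead.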
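To illustrate how the split alone can change the outcome, here is a minimal sketch that refits a depth-1 tree on several different train/test splits and lists which classes each fitted tree is able to predict. The helper name predicted_classes and the seed range are my own choices for illustration, not part of the question's code.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

dataset = load_iris()


def predicted_classes(seed):
    """Fit a depth-1 tree on the split selected by `seed` (hypothetical helper)
    and return the sorted list of classes that tree predicts."""
    X_train, X_test, y_train, y_test = train_test_split(
        dataset.data, dataset.target, test_size=0.3, random_state=seed)
    clf = DecisionTreeClassifier(max_depth=1, random_state=0)
    clf.fit(X_train, y_train)
    # Predict on the full dataset so that both leaves are reached.
    return sorted(set(clf.predict(dataset.data).tolist()))


# A depth-1 tree has two leaves, so it can predict at most two classes.
# Which two depends on the class counts in the training split.
for seed in range(5):
    print(seed, predicted_classes(seed))
```

If the printed lists differ between seeds, that is the effect described above: the plotted tree and the predicting tree only agree when they were fitted on the same split.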

Try the following code with different random_state values:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree

dataset = load_iris()

# Fix the split so the same training set is used on every run
X_train, X_test, y_train, y_test = train_test_split(dataset.data,
                                                    dataset.target,
                                                    test_size=0.3,
                                                    random_state=0)

# Fix the randomness of the tree itself as well
clf = DecisionTreeClassifier(max_depth=1, random_state=1)
clf.fit(X_train, y_train)
print(clf.predict(X_test))

# Plot the same fitted model that produced the predictions
plot_tree(clf)

[Image: output of plot_tree(clf)]

Note: plot_tree requires scikit-learn 0.21 or later.
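As for how the class is chosen inside a leaf, the following sketch reads the per-class values stored in each leaf of the fitted tree and checks that predict returns the class with the largest value. It relies on the fitted estimator's tree_ attribute and its apply method. The variable names are mine, and depending on the scikit-learn version the stored values may be raw counts or class fractions, which does not change the argmax.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    dataset.data, dataset.target, test_size=0.3, random_state=0)

clf = DecisionTreeClassifier(max_depth=1, random_state=1)
clf.fit(X_train, y_train)

tree = clf.tree_
# Leaves are the nodes without children (children_left == -1).
leaf_ids = np.where(tree.children_left == -1)[0]
for node_id in leaf_ids:
    dist = tree.value[node_id][0]  # per-class counts or fractions
    print(node_id, dist, "-> predicts", clf.classes_[np.argmax(dist)])

# For every test sample, look up its leaf and take that leaf's majority class.
leaf_of_sample = clf.apply(X_test)
expected = clf.classes_[np.argmax(tree.value[leaf_of_sample, 0, :], axis=1)]

# This should agree with what predict returns.
matches = bool((expected == clf.predict(X_test)).all())
print(matches)
```

Because the leaf contents come from the training split, comparing these printed values with the plot is a quick way to confirm that both were produced by the same model.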

Upvotes: 1
