Reputation: 53786
The decision tree below is generated using this code:
import graphviz
from sklearn.tree import DecisionTreeClassifier, export_graphviz

dt = DecisionTreeClassifier()
dt = clf.fit([[1],[2],[3]], [[3],[2],[3]])
dot_data = export_graphviz(dt, out_file=None,
                           feature_names=['1', '2', '3'],
                           class_names=['true', 'false'],
                           filled=True, rounded=True,
                           special_characters=True)
graph = graphviz.Source(dot_data)
graph
If I instead use:
dt = DecisionTreeClassifier()
dt = clf.fit([[1],[2],[3]], [[2],[3],[4]])
dot_data = export_graphviz(dt, out_file=None,
                           feature_names=['1', '2', '3'],
                           class_names=['true', 'false'],
                           filled=True, rounded=True,
                           special_characters=True)
graph = graphviz.Source(dot_data)
graph
the following error is returned:
/usr/local/lib/python3.5/dist-packages/sklearn/tree/export.py in node_to_str(tree, node_id, criterion)
284 node_string += 'class = '
285 if class_names is not True:
--> 286 class_name = class_names[np.argmax(value)]
287 else:
288 class_name = "y%s%s%s" % (characters[1],
IndexError: list index out of range
Is this a quirk of the visualization, given that the classifier trains correctly?
Upvotes: 0
Views: 10375
Reputation: 8801
First, there is a typo in your code: you are calling clf.fit, whereas it should be dt.fit.
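Second, in the second example your targets contain three distinct labels (2, 3 and 4), but class_names lists only two names, true and false. export_graphviz picks each node's class name with class_names[np.argmax(value)] (the line in your traceback), so any node whose majority class is the third class (4) asks for class_names[2], which does not exist, and that raises the IndexError. In your first example the targets 3, 2, 3 contain only two classes, which is why two names were enough there. The names are matched to the classes in ascending order (dt.classes_), so you need one name per class; just add a third name besides true and false and it should work correctly: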
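import graphviz
from sklearn.tree import DecisionTreeClassifier, export_graphviz

dt = DecisionTreeClassifier()
dt = dt.fit([[1],[2],[3]], [[2],[3],[4]])  # It should be dt.fit, not clf.fit
# X has a single column, so pass exactly one feature name ('x' is just a placeholder);
# recent scikit-learn versions raise a ValueError if the count does not match.
dot_data = export_graphviz(dt, out_file=None,
                           feature_names=['x'],
                           class_names=['true', 'false', 'something_else'],  # one name per class: 2, 3, 4
                           filled=True, rounded=True,
                           special_characters=True)
graph = graphviz.Source(dot_data)
graph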
Now it should work correctly. Feel free to name the third label as you like. The error occurred because you did not provide a name for the third class, so export_graphviz could not map the actual labels (2, 3 and 4) onto the names you gave in class_names (true and false).
Upvotes: 3