Reputation: 186
I am trying to generate a visualisation of a decision tree, but I am getting an error that I cannot resolve. This is my code:
from sklearn.tree import export_graphviz
from sklearn.externals.six import StringIO
from IPython.display import Image
import pydotplus
feature_cols = ['Reason_for_absence', 'Month_of_absence']
feature_cols
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data, filled=True, rounded=True, special_characters=True, feature_names = feature_cols,class_names['0', '1'])
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('tree.png')
Image(graph.create_png())
I am getting the following error:
File "", line 9
export_graphviz(clf, out_file=dot_data, filled=True, rounded=True, special_characters=True, feature_names = feature_cols,class_names['0', '1'])
^
SyntaxError: positional argument follows keyword argument
EDIT:
I have changed the code according to the answer, and now I am getting a different error:
IndexError: list index out of range
The code was also amended a bit:
feature_cols = ['Reason_for_absence',
                'Month_of_absence',
                'Day_of_the_week',
                'Seasons',
                'Transportation_expense',
                'Distance_from_Residence_to_Work',
                'Service_time',
                'Age',
                'Work_load_Average/day ',
                'Hit_target',
                'Disciplinary_failure',
                'Education',
                'Son',
                'Social_drinker',
                'Social_smoker',
                'Pet',
                'Weight',
                'Height',
                'Bod_mass_index',
                'Absenteeism']
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data, filled=True, rounded=True, special_characters=True, feature_names = feature_cols, class_names=['0', '1'])
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('tree.png')
Image(graph.create_png())
Upvotes: 1
Views: 303
Reputation: 6148
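You are missing an = in the last argument. Without it, class_names['0', '1'] is parsed as an indexing expression passed positionally, and Python does not allow a positional argument after keyword arguments. Update the last argument to class_names=['0', '1']: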
export_graphviz(clf, out_file=dot_data, filled=True, rounded=True,
                special_characters=True,
                feature_names=feature_cols,
                class_names=['0', '1'])
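Regarding the IndexError from the edit: one common cause, at least in older scikit-learn versions, is that feature_names or class_names does not match what the tree was fitted on. The list in the edit has 20 names and appears to include the target column, and class_names has only two entries, so either could be the culprit. I cannot confirm this without the training code. Below is a minimal sketch of a sanity check; check_export_names is a hypothetical helper, not part of scikit-learn, and it takes plain numbers so it can be tried without a fitted model.

```python
def check_export_names(n_features, n_classes, feature_names, class_names):
    """Return a list of problems; the list is empty if the lengths look fine.

    Hypothetical helper for illustration only (not a scikit-learn function).
    n_features / n_classes are the counts the tree was actually fitted on.
    """
    problems = []
    if len(feature_names) != n_features:
        problems.append(
            "feature_names has %d entries, but the tree was fitted on %d features"
            % (len(feature_names), n_features)
        )
    if len(class_names) < n_classes:
        problems.append(
            "class_names has %d entries, but the tree has %d classes"
            % (len(class_names), n_classes)
        )
    return problems


# Made-up numbers: 20 names for a tree fitted on 19 features and 3 classes.
for problem in check_export_names(19, 3, ["f%d" % i for i in range(20)], ["0", "1"]):
    print(problem)
```

With a fitted DecisionTreeClassifier you would pass the real counts, for example check_export_names(clf.tree_.n_features, len(clf.classes_), feature_cols, ['0', '1']). If it reports a mismatch, build feature_cols from the columns actually used for fitting (something like list(X.columns), assuming X is the training DataFrame) and give one class name per entry in clf.classes_.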
Upvotes: 1