My code:
```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold

# Accumulate one confusion matrix per fold, then sum them
conf_matrix_list_of_arrays = []
kf = KFold(n_splits=10)
for i, (train_index, test_index) in enumerate(kf.split(X, y)):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
    model.fit(X_train, y_train)
    conf_matrix = confusion_matrix(y_test, model.predict(X_test))
    conf_matrix_list_of_arrays.append(conf_matrix)
cf_matrix = np.sum(conf_matrix_list_of_arrays, axis=0)

# Visualization of confusion matrix
group_names = ['True Neg', 'False Pos', 'False Neg', 'True Pos']
group_counts = ["{0:0.0f}".format(value) for value in
                cf_matrix.flatten()]
group_percentages = ["{0:.2%}".format(value) for value in
                     cf_matrix.flatten() / np.sum(cf_matrix)]
labels = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in
          zip(group_names, group_counts, group_percentages)]
labels = np.asarray(labels).reshape(2, 2)
sns_plot = sns.heatmap(cf_matrix, annot=labels, fmt='', cmap='Blues').set_title('Confusion Matrix ' + model_name)
sns_plot.yaxis.set_ticklabels(np.unique(y_pred).tolist()[::-1])
plt.show()
figure = sns_plot.get_figure()
figure.savefig(result_path + 'Confusion Matrix ' + model_name + '.png', dpi=400)
```
Error I got:
`AttributeError: 'Text' object has no attribute 'yaxis'`
raised on this line: `sns_plot.yaxis.set_ticklabels(np.unique(y_pred).tolist()[::-1])`
My confusion matrix has 0 and 1 as the y-axis tick labels, and I want to change them to strings. `np.unique(y_pred).tolist()[::-1]` is the list `['agreed','disagreed']`, which I want to use as the y-axis labels.
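In your code, this line (the one assigned to `sns_plot`):
```python
sns.heatmap(cf_matrix, annot=labels, fmt='', cmap='Blues').set_title('Confusion Matrix '+model_name)
```
does not give you a matplotlib Axes object. `sns.heatmap()` itself returns the Axes, but you immediately call `.set_title()` on it, and `Axes.set_title()` returns a `matplotlib.text.Text` object. That `Text` object is what ends up in `sns_plot`, and it has no `yaxis` attribute.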
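A quick way to confirm this is to check the types of the two return values. The sketch below uses a toy 2x2 matrix (an assumption, standing in for `cf_matrix`) and the non-interactive `Agg` backend so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the check runs anywhere
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.text import Text

# Toy confusion matrix (hypothetical data, stands in for cf_matrix)
cm = np.array([[50, 5], [8, 37]])

ax = sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
title = ax.set_title("Confusion Matrix")

# heatmap() returns the Axes; set_title() returns the Text it created
print(isinstance(ax, Axes))         # True
print(isinstance(title, Text))      # True
print(hasattr(title, "yaxis"))      # False: this is why the original line fails
plt.close("all")
```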
Instead, assign the result of `sns.heatmap()` to `sns_plot` and call `set_title()` in a separate statement:
```python
sns_plot = sns.heatmap(cf_matrix, annot=labels, fmt='', cmap='Blues')
sns_plot.set_title('Confusion Matrix ' + model_name)
sns_plot.yaxis.set_ticklabels(np.unique(y_pred).tolist()[::-1])
```
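For reference, here is a self-contained sketch of the two-step pattern together with the annotation and save steps from the question. The matrix values, `model_name` and `class_names` are hypothetical, and the figure is written to an in-memory buffer instead of `result_path` so the sketch has no file-system side effects. Note that seaborn draws matrix row 0 at the top tick, so in this sketch the tick labels are given in the same order as the matrix rows:

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

cf_matrix = np.array([[50, 5], [8, 37]])  # hypothetical toy counts
model_name = "ExampleModel"               # hypothetical model name
class_names = ["agreed", "disagreed"]     # hypothetical labels, in row order

# Build the "name / count / percentage" annotation for each cell
group_names = ["True Neg", "False Pos", "False Neg", "True Pos"]
group_counts = [f"{v:0.0f}" for v in cf_matrix.flatten()]
group_percentages = [f"{v:.2%}" for v in cf_matrix.flatten() / cf_matrix.sum()]
labels = np.asarray(
    [f"{a}\n{b}\n{c}" for a, b, c in zip(group_names, group_counts, group_percentages)]
).reshape(2, 2)

# Keep the Axes returned by heatmap(); set the title in a separate statement
ax = sns.heatmap(cf_matrix, annot=labels, fmt="", cmap="Blues")
ax.set_title("Confusion Matrix " + model_name)
ax.yaxis.set_ticklabels(class_names)

# The Axes knows its Figure, which is what savefig() needs
figure = ax.get_figure()
buf = io.BytesIO()
figure.savefig(buf, format="png", dpi=100)
plt.close(figure)
```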
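Also note that you can pass `xticklabels` or `yticklabels` directly to `sns.heatmap()`. Keep the returned Axes in `sns_plot` and set the title separately, so the rest of your code (for example `sns_plot.get_figure()`) still works on an Axes:
```python
sns_plot = sns.heatmap(cf_matrix, annot=labels, fmt='', cmap='Blues',
                       yticklabels=np.unique(y_pred).tolist()[::-1])
sns_plot.set_title('Confusion Matrix ' + model_name)
```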
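Passing tick labels this way only gives correct results if they follow the same order as the matrix rows and columns. A sketch, assuming string class labels and hypothetical `y_true`/`y_pred` arrays: scikit-learn's `confusion_matrix` accepts a `labels=` argument that fixes that order explicitly, and the same list can then be passed to both `xticklabels` and `yticklabels`:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Hypothetical string-valued true and predicted labels
y_true = np.array(["agreed", "disagreed", "agreed", "agreed", "disagreed", "disagreed"])
y_pred = np.array(["agreed", "agreed", "agreed", "agreed", "disagreed", "disagreed"])

# Fix the class order explicitly; row i / column i of cm correspond to class_names[i]
class_names = ["agreed", "disagreed"]
cm = confusion_matrix(y_true, y_pred, labels=class_names)

# Row 0 is drawn at the top and column 0 at the left, so the labels keep the same order
ax = sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                 xticklabels=class_names, yticklabels=class_names)
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
plt.close(ax.get_figure())
```

Here `cm` is `[[3, 0], [1, 2]]`: three "agreed" samples predicted correctly, one "disagreed" sample predicted as "agreed", and two "disagreed" samples predicted correctly.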