Reputation: 3029
I wish to plot the confusion matrix for my classification model. It has about 20000 documents that need to be classified into 90 classes. The confusion matrix I receive is huge. I wish to plot it, but I only seem to find binary classification plots everywhere. Is it possible to plot this multi-class confusion matrix? I tried some methods, but none of them produced a clear plot.
This is what my confusion matrix looks like:
[[3919 344 0 ..., 0 0 1]
[ 267 2739 0 ..., 0 0 0]
[ 1 6 17 ..., 0 0 0]
...,
[ 4 1 0 ..., 6 0 0]
[ 0 2 0 ..., 0 0 0]
[ 6 1 0 ..., 0 0 15]]
Upvotes: 4
Views: 8843
Reputation: 1017
This is my approach:
import logging
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
...
def generate_confusion_matrix(y_test, y_pred):
    """Generates a confusion matrix plot based on the given values.
    Args:
        y_test (any): the resulting y_test of the function "train_test_split".
        y_pred (any): the resulting value of the function "predict".
    Returns:
        None.
    """
    logger = logging.getLogger('ThreatTrekker')
    logger.debug('Plotting confusion matrix')
    cm = confusion_matrix(y_test, y_pred)
    # normalize each row so a cell shows the fraction of that true class
    cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    # plot the confusion matrix using seaborn
    sns.set(rc={'figure.figsize': (10, 6)})  # Size in inches
    sns.heatmap(cm_norm, annot=True, cmap='Blues', fmt='.2f')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    # PLOTS_PATH is an output directory defined elsewhere (not shown here)
    plt.savefig(PLOTS_PATH + 'Confusion Matrix')
    plt.show()
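One caveat about the row normalization above: if a class shows up only in y_pred and never in y_test, its row sums to zero and the division produces NaN. A minimal sketch of a guard for that case, using only NumPy (normalize_rows is a hypothetical helper name, not part of sklearn or seaborn):

```python
import numpy as np

def normalize_rows(cm):
    """Divide each row by its sum; rows that sum to zero stay all zeros."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # `where` skips the zero-sum rows, so they keep the zeros supplied via `out`
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

# toy 2-class matrix whose second class has no true samples
cm_norm = normalize_rows([[3, 1], [0, 0]])
```

The result can be passed to sns.heatmap in place of cm_norm in the function above.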
If you need a bigger plot, increase the figure size passed to the sns.set
function.
A matrix of 14 classes looks like this with the previous figure size:
Upvotes: 0
Reputation: 404
Another solution would be to plot only the classes with the highest number of samples.
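from collections import Counter
from sklearn.metrics import confusion_matrix
top_n = 20
top_classes = [label for label, _ in Counter(y_true).most_common(top_n)] + ["_other"]
top_y_true = [y if y in top_classes else "_other" for y in y_true]
top_y_pred = [y if y in top_classes else "_other" for y in y_pred]
# pass labels= so rows and columns follow top_classes; by default they are sorted and would not match the tick labels used when plotting
cm = confusion_matrix(top_y_true, top_y_pred, labels=top_classes)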
Then you can plot it with the tools of your liking, e.g.:
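import matplotlib.pyplot as plt
import seaborn as sns
fig, ax = plt.subplots(figsize=(20, 20))
sns_plot = sns.heatmap(cm, xticklabels=top_classes, yticklabels=top_classes)
# conf is my own experiment config (not shown); any title string works here
sns_plot.set(title=conf.experiment.run.run_name)
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.show(block=False)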
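To see what the collapsing step does, here is a small worked sketch on toy data. The labels and the helper name top_n_confusion are made up for illustration, and it assumes string class labels so that "_other" can sit alongside them:

```python
from collections import Counter
from sklearn.metrics import confusion_matrix

def top_n_confusion(y_true, y_pred, top_n, other="_other"):
    """Confusion matrix over the top_n most frequent true classes plus one catch-all class."""
    top = [label for label, _ in Counter(y_true).most_common(top_n)]
    keep = set(top)
    labels = top + [other]
    collapsed_true = [y if y in keep else other for y in y_true]
    collapsed_pred = [y if y in keep else other for y in y_pred]
    # labels= fixes the row/column order so it matches the returned label list
    return confusion_matrix(collapsed_true, collapsed_pred, labels=labels), labels

# toy data: "a" and "b" are the two most frequent true classes
y_true = ["a", "a", "a", "b", "b", "c", "d"]
y_pred = ["a", "a", "b", "b", "c", "c", "a"]
cm, labels = top_n_confusion(y_true, y_pred, top_n=2)
```

Here "c" and "d" are both folded into "_other", so the matrix is 3x3 with rows and columns ordered a, b, _other.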
Upvotes: 2
Reputation: 141
Disclaimer,
Hi,
I think plotting a confusion matrix of this size is not a good solution. I suggest saving it as an HTML or CSV file instead.
PyCM is a Python module that can present a multi-class confusion matrix in several report formats, such as an HTML report.
Here is a simple line of code for saving an HTML report of a confusion matrix:
cm.save_html("file_name",color=(R,G,B))
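If you would rather not add a dependency, the same idea (export instead of plot) can be sketched with pandas rather than PyCM. This is only an illustration: the 3-class counts and label names below are made up, and the output file name in the comment is a placeholder.

```python
import numpy as np
import pandas as pd

# hypothetical 3-class confusion matrix and label names, for illustration only
labels = ["class_a", "class_b", "class_c"]
counts = np.array([[5, 1, 0],
                   [2, 7, 1],
                   [0, 0, 4]])

# rows are true classes, columns are predicted classes
df = pd.DataFrame(counts, index=labels, columns=labels)
df.index.name = "true"
df.columns.name = "predicted"

# without a path these return strings; pass a path such as
# df.to_csv("confusion_matrix.csv") to write a file instead
csv_text = df.to_csv()
html_text = df.to_html()
```

A labeled table like this can be opened in a spreadsheet and sorted or filtered, which is often easier than reading a 90x90 image.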
Upvotes: 2
Reputation: 465
Here is some sample code using matplotlib (EDIT: added grid and switching off the interpolation)
import numpy as np
import matplotlib.pyplot as plt
confmat=np.random.rand(90,90)
ticks=np.linspace(0, 89,num=90)
plt.imshow(confmat, interpolation='none')
plt.colorbar()
plt.xticks(ticks,fontsize=6)
plt.yticks(ticks,fontsize=6)
plt.grid(True)
plt.show()
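With real counts like the ones in the question (thousands on the diagonal, single digits elsewhere), a linear color scale hides most of the off-diagonal cells. One possible variation on the code above is to color by log(1 + count). This is a sketch with random stand-in data; plot_log_confusion and tick_step are names I made up:

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_log_confusion(cm, tick_step=10):
    """Show log(1 + count) so rare confusions stay visible next to large diagonal counts."""
    cm = np.asarray(cm)
    fig, ax = plt.subplots(figsize=(10, 10))
    image = ax.imshow(np.log1p(cm), interpolation='none', cmap='Blues')
    fig.colorbar(image, ax=ax, label='log(1 + count)')
    # label only every tick_step-th class to keep the axes readable
    ticks = np.arange(0, cm.shape[0], tick_step)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    return fig, ax

# random stand-in for a 90-class confusion matrix
rng = np.random.default_rng(0)
fig, ax = plot_log_confusion(rng.integers(0, 4000, size=(90, 90)))
```

Call plt.show() or fig.savefig(...) afterwards to display or save the figure.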
Upvotes: 4