minks

Reputation: 3029

Is it possible to plot a confusion matrix with 90 classes?

I want to plot the confusion matrix for my classification model. It has about 20,000 documents that need to be classified into 90 classes. The resulting confusion matrix is huge. I want to plot it, but I only seem to find binary classification plots everywhere. Is it possible to plot this multi-class confusion matrix? I tried some methods, but none of them displays it clearly.

This is what my confusion matrix looks like:

[[3919  344    0 ...,    0    0    1]
 [ 267 2739    0 ...,    0    0    0]
 [   1    6   17 ...,    0    0    0]
 ..., 
 [   4    1    0 ...,    6    0    0]
 [   0    2    0 ...,    0    0    0]
 [   6    1    0 ...,    0    0   15]]

Upvotes: 4

Views: 8843

Answers (4)

A.Casanova

Reputation: 1017

This is my approach:

import logging

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

...

def generate_confusion_matrix(y_test, y_pred):
    """ Generates a confusion_matrix plot based on the given values.
    Args:
        y_test (any): the y_test returned by "train_test_split".
        y_pred (any): the output of the model's "predict" method.
    Returns:
        None. The plot is saved to disk and shown.
    """
    logger = logging.getLogger('ThreatTrekker')
    logger.debug('Plotting confusion matrix')

    cm = confusion_matrix(y_test, y_pred)
    # Normalize each row so cells show the fraction of each true class
    cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    # plot the confusion matrix using seaborn
    sns.set(rc={'figure.figsize': (10, 6)})  # Size in inches
    sns.heatmap(cm_norm, annot=True, cmap='Blues', fmt='.2f')

    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    # PLOTS_PATH is assumed to be defined elsewhere (not shown in this snippet)
    plt.savefig(PLOTS_PATH + 'Confusion Matrix')
    plt.show()

If you need to make the matrix bigger, just pass a larger figure size to sns.set.

With the figure size above, a matrix of 14 classes looks like this: (image not included)
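
One caveat with the row normalization in the function above: if a class never appears in y_test, its row sum is zero and the division produces NaN values. A minimal sketch of a guarded version, assuming plain NumPy; the helper name normalize_rows is hypothetical and not part of the original answer:

```python
import numpy as np

def normalize_rows(cm):
    """Row-normalize a confusion matrix; rows with no samples stay all zero."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # Divide only where the row sum is non-zero; other rows keep the zeros from `out`.
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

cm = np.array([[3, 1, 0],
               [0, 0, 0],   # a class with no true samples
               [1, 0, 1]])
cm_norm = normalize_rows(cm)
```

The result can be passed to sns.heatmap in place of cm_norm.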

Upvotes: 0

Sviat Lavrinchuk

Reputation: 404

Another solution would be to plot only the classes with the highest number of samples.

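from collections import Counter
from sklearn.metrics import confusion_matrix

top_n = 20
top_classes = [label for label, _ in Counter(y_true).most_common(top_n)] + ["_other"]
top_y_true = [y if y in top_classes else "_other" for y in y_true]
top_y_pred = [y if y in top_classes else "_other" for y in y_pred]
# Pass labels explicitly so rows/columns follow top_classes, not sorted order
cm = confusion_matrix(top_y_true, top_y_pred, labels=top_classes)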

Then you can plot it with the tools of your liking, e.g.:

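import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots(figsize=(20, 20))
sns_plot = sns.heatmap(cm, xticklabels=top_classes, yticklabels=top_classes, ax=ax)
# conf is a project-specific config object; any title string works here
sns_plot.set(title=conf.experiment.run.run_name)
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.show(block=False)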
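
The collapsing step can also be sketched without scikit-learn, which makes the row and column order explicit. This is only an illustration under assumptions: the function name collapse_to_top_n and the toy labels are made up here, and ties between equally frequent classes are resolved by first occurrence in y_true.

```python
from collections import Counter

def collapse_to_top_n(y_true, y_pred, top_n, other="_other"):
    """Build a count matrix over the top_n most frequent true classes plus `other`."""
    top = [label for label, _ in Counter(y_true).most_common(top_n)]
    labels = top + [other]
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        # Anything outside the top-N set lands in the last row/column.
        row = index.get(t, index[other])
        col = index.get(p, index[other])
        matrix[row][col] += 1
    return labels, matrix

y_true = ["a", "a", "a", "b", "b", "c", "d"]
y_pred = ["a", "a", "b", "b", "c", "c", "a"]
labels, matrix = collapse_to_top_n(y_true, y_pred, top_n=2)
# labels -> ['a', 'b', '_other']
# matrix -> [[2, 1, 0], [0, 1, 1], [1, 0, 1]]
```

Rows are true classes and columns are predicted classes, both in the order given by labels, so the same list can be used for the tick labels of the heatmap.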

Upvotes: 2

Alireza Zolanvari

Reputation: 141

Disclaimer,

Hi,

I don't think plotting a confusion matrix this large is a good solution. I suggest saving it as an HTML or CSV file instead.

PyCM is a Python module that can present a multi-class confusion matrix through different types of reports, such as an HTML report.

Here is a simple line of code for saving an HTML report of a confusion matrix:

# cm is a pycm.ConfusionMatrix object; (R, G, B) is an RGB color tuple such as (0, 0, 255)
cm.save_html("file_name",color=(R,G,B))
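
For the CSV route, the standard library is enough. The following is a rough sketch rather than anything PyCM provides: the function name confusion_matrix_to_csv, the header cell text, and the toy labels are all assumptions made for illustration.

```python
import csv
import io

def confusion_matrix_to_csv(matrix, labels):
    """Serialize a confusion matrix as CSV text with a header row and a label column."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    # First row: column headers (predicted classes).
    writer.writerow(["true/predicted"] + list(labels))
    # One row per true class, prefixed with its label.
    for label, row in zip(labels, matrix):
        writer.writerow([label] + list(row))
    return buffer.getvalue()

csv_text = confusion_matrix_to_csv([[5, 1], [2, 7]], ["spam", "ham"])
```

To write a file instead of a string, the same writer can be attached to open("confusion_matrix.csv", "w", newline=""); a 90 x 90 table then opens in any spreadsheet tool, where it can be sorted and filtered.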

Upvotes: 2

shiftyscales

Reputation: 465

Here is some sample code using matplotlib (EDIT: added a grid and switched off the interpolation):

import numpy as np
import matplotlib.pyplot as plt

confmat = np.random.rand(90, 90)  # random stand-in for a 90-class confusion matrix
ticks = np.linspace(0, 89, num=90)
plt.imshow(confmat, interpolation='none')
plt.colorbar()
plt.xticks(ticks, fontsize=6)
plt.yticks(ticks, fontsize=6)
plt.grid(True)
plt.show()


Upvotes: 4
