son520804

Reputation: 483

How to annotate only the diagonal elements of a seaborn heatmap

I am using a Seaborn heatmap to plot a large confusion matrix. Since the diagonal elements represent the correct predictions, they are the ones whose numbers (or correct rates) matter most to display. As the title says: how can I annotate only the diagonal entries of a heatmap?

I have consulted https://seaborn.pydata.org/examples/many_pairwise_correlations.html, but it does not show how to annotate only the diagonal entries. I hope somebody can help with that. Thank you in advance!

Upvotes: 5

Views: 4485

Answers (2)

JohanC

Reputation: 80339

In a related question, someone asked how to annotate the diagonal elements with strings. Here is an example:

from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np

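flights = sns.load_dataset('flights')
# pandas >= 2.0 requires keyword arguments for DataFrame.pivot
flights = flights.pivot(index='year', columns='month', values='passengers')
# correlation between the yearly passenger profiles (one row per year)
corr_data = np.corrcoef(flights.to_numpy())

# mask the upper triangle, diagonal included, so those cells stay empty
up_triang = np.triu(np.ones_like(corr_data)).astype(bool)
ax = sns.heatmap(corr_data, cmap='flare', xticklabels=False, yticklabels=False, square=True,
                 linecolor='white', linewidths=0.5,
                 cbar=True, mask=up_triang, cbar_kws={'shrink': 0.6, 'pad': 0.02, 'label': 'correlation'})
ax.invert_xaxis()
# write each year label into its (masked, hence empty) diagonal cell
for i, label in enumerate(flights.index):
    ax.text(i + 0.2, i + 0.5, label, ha='right', va='center')
plt.show()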

[Image: heatmap with annotated diagonal]
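The same ax.text idea also works for writing the numbers themselves into the diagonal cells of a confusion matrix. A minimal sketch, assuming a small made-up 3-class matrix `cm` and relying on seaborn drawing cell (row i, column j) over the unit square from (j, i) to (j + 1, i + 1) in data coordinates, so the centre of diagonal cell i is (i + 0.5, i + 0.5):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line to get an interactive window
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Hypothetical 3-class confusion matrix: rows = true class, columns = predicted class
cm = np.array([[50, 2, 3],
               [4, 40, 6],
               [1, 5, 44]])

fig, ax = plt.subplots()
sns.heatmap(cm, cmap="Blues", square=True, ax=ax)  # no annot: every cell starts unlabeled

# Cell (row i, column j) is centred at data coordinates (j + 0.5, i + 0.5),
# so diagonal cell i is centred at (i + 0.5, i + 0.5)
for i in range(cm.shape[0]):
    ax.text(i + 0.5, i + 0.5, str(cm[i, i]),
            ha="center", va="center", color="white")

plt.show()
```

Because no annot matrix is passed, seaborn adds no text of its own and only the diagonal labels appear.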

Upvotes: 2

QuantStats

Reputation: 1486

Does this give you what you have in mind? The example at your URL masks the main diagonal, so I annotated the diagonal just below it instead. To annotate the main diagonal of your confusion matrix, adapt my code by changing the -1 offset in np.diag(np.diag(..., -1), -1) to 0 in both calls, and make sure your mask does not hide the diagonal cells (the mask built below with np.triu_indices_from includes the diagonal).
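To see what the offset does, here is a small NumPy-only sketch on a made-up 3×3 matrix `m` (a stand-in for corr.values or a confusion matrix). The inner np.diag extracts the k-th diagonal as a 1-D array and the outer np.diag puts it back on the k-th diagonal of an otherwise-zero matrix. It also compares the mask style used in the code with np.triu(..., k=1), which leaves the main diagonal visible:

```python
import numpy as np

m = np.arange(1, 10).reshape(3, 3)  # hypothetical stand-in matrix
# [[1 2 3]
#  [4 5 6]
#  [7 8 9]]

# k = -1: only the sub-diagonal (4 and 8) survives, everything else is 0
sub = np.diag(np.diag(m, -1), -1)
# k = 0: only the main diagonal (1, 5, 9) survives
main = np.diag(np.diag(m, 0), 0)

# Mask built with np.triu_indices_from: covers the upper triangle AND the diagonal
mask_incl = np.zeros_like(m, dtype=bool)
mask_incl[np.triu_indices_from(mask_incl)] = True

# Mask built with np.triu(..., k=1): strictly above the diagonal, diagonal stays visible
mask_excl = np.triu(np.ones_like(m, dtype=bool), k=1)
```

For a confusion matrix you would usually drop the mask entirely; mask_excl only matters if you still want to hide the upper triangle.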

Note the additional parameter fmt='' that I added to sns.heatmap(...): it is needed because the elements of my annot matrix are strings.
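Why fmt='' matters can be illustrated in plain Python. The sketch assumes that seaborn turns each annot entry into text with something like ("{:" + fmt + "}").format(value); the helper format_cell below is hypothetical and only mimics that. With seaborn's default fmt='.2g', a string entry cannot be formatted, while fmt='' passes the string through unchanged:

```python
def format_cell(value, fmt):
    # Hypothetical helper, assumed to mirror how a heatmap annotation label is built
    return ("{:" + fmt + "}").format(value)

print(format_cell(0.456, ".2g"))  # 0.46 (numeric entry, default-style fmt)
print(format_cell("0.46", ""))    # 0.46 (string entry passed through unchanged)

try:
    format_cell("0.46", ".2g")    # default-style fmt applied to a string entry
except ValueError as err:
    print("ValueError:", err)
```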

Code

from string import ascii_letters
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

sns.set(style="white")

# Generate a large random dataset
rs = np.random.RandomState(33)
y = rs.normal(size=(100, 26))
d = pd.DataFrame(data=y, columns=list(ascii_letters[26:]))

# Compute the correlation matrix
corr = d.corr()

# Generate a mask for the upper triangle
mask = np.zeros_like(corr, dtype='bool')
mask[np.triu_indices_from(mask)] = True

# Set up the matplotlib figure
f, ax = plt.subplots(figsize=(11, 9))

# Generate a custom diverging colormap
cmap = sns.diverging_palette(220, 10, as_cmap=True)

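# Generate the annotation: keep only the sub-diagonal (offset -1); every other entry becomes 0
annot = np.diag(np.diag(corr.values, -1), -1)
annot = np.round(annot, 2)
annot = annot.astype('str')
# Blank out the zero entries, i.e. every cell off the sub-diagonal
# (a sub-diagonal value that rounds to 0.0 is blanked too)
annot[annot == '0.0'] = ''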

# Draw the heatmap with the mask and correct aspect ratio
sns.heatmap(corr, mask=mask, cmap=cmap, vmax=.3, center=0,
            square=True, linewidths=.5, cbar_kws={"shrink": .5}, annot=annot, fmt='')

plt.show()

Output: [image of the resulting heatmap]
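A variant of the code above that avoids the '0.0' string comparison (so genuine values that round to zero are not blanked) is to select the diagonal with a boolean mask instead. A minimal sketch, assuming a hypothetical helper diagonal_annot and a made-up 3-class confusion matrix `cm`; the "correct rate" is taken here to be the diagonal count divided by the row total, which is an assumption about what you want to show:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line to get an interactive window
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def diagonal_annot(values, k=0, fmt="{:.2f}"):
    """Hypothetical helper: string matrix that is empty except on the k-th diagonal.

    k=0 selects the main diagonal, k=-1 the one just below it.
    """
    values = np.asarray(values)
    annot = np.full(values.shape, "", dtype=object)
    rows, cols = np.indices(values.shape)
    on_diag = (cols - rows) == k  # True exactly on the k-th diagonal
    annot[on_diag] = [fmt.format(v) for v in values[on_diag]]
    return annot


# Hypothetical confusion matrix: rows = true class, columns = predicted class
cm = np.array([[50, 2, 3],
               [4, 40, 6],
               [1, 5, 44]])
# Assumed per-class correct rate: diagonal count / row total
rates = cm / cm.sum(axis=1, keepdims=True)
annot = diagonal_annot(rates, k=0, fmt="{:.0%}")

fig, ax = plt.subplots()
# No mask, so the diagonal cells stay visible; fmt="" because annot already holds strings
sns.heatmap(cm, annot=annot, fmt="", cmap="Blues", square=True, ax=ax)
plt.show()
```

Passing k=-1 and the correlation values reproduces the sub-diagonal labels of the answer's example.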

Upvotes: 6
