gha7all

Reputation: 65

seaborn.heatmap doesn't adjust colors based on a specific value

I want a heatmap of a confusion matrix whose colors are based on the percentage of each class. For example, the highest percentage should be black and the others lighter in proportion to their percentages (higher percentage = darker color). I tried changing vmin and vmax, but the colors still follow the counts, not group_percentages.

import numpy as np
import seaborn as sns

# cf is the 3x3 confusion matrix (a NumPy array) computed earlier
categories = ['a', 'b', 'c']
group_percentages = []
counts = []
for i in range(len(cf)):          # i: column index
    for j in range(len(cf)):      # j: row index
        # share of this cell in its column's total
        group_percentages.append(cf[j, i] / np.sum(cf[:, i]))
        counts.append(cf[j, i])

group_percentages = ['{0:.2%}'.format(value) for value in group_percentages]

counts = ['{0:0.0f}'.format(value) for value in counts]
labels = [f'{v1}\n{v2}' for v1, v2 in zip(group_percentages, counts)]
# the lists were filled column by column, so rebuild the grid in Fortran order
labels = np.asarray(labels).reshape(3, 3, order='F')

sns.heatmap(cf, annot=labels, fmt='', xticklabels=categories, yticklabels=categories, cmap='Greys', vmax=100, cbar=False)

Output:

[Image: the resulting heatmap]

As you can see, even though I set vmax to 100 and cf[0,0] is 100%, its cell is gray, while cf[1,1] is 89% and its cell is black.
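A minimal sketch of why the cells come out this way, assuming the counts from the answer's example matrix (23 in cf[0,0], 106 in cf[1,1]) and assuming seaborn picks vmin = 0 as the smallest count. The heatmap maps each data value linearly between vmin and vmax (matplotlib's Normalize) and then looks it up in the colormap, so with counts as the data a value of 23 lands near the light end and 106 overshoots vmax and is drawn with the darkest color.

```python
import matplotlib
from matplotlib.colors import Normalize

greys = matplotlib.colormaps["Greys"]  # white at 0.0, black at 1.0

# Assumed counts from the answer's example matrix: cf[0, 0] = 23, cf[1, 1] = 106.
# Question's setup: color by counts, vmin = 0 (smallest count), vmax = 100.
by_counts = Normalize(vmin=0, vmax=100)
gray_00 = greys(by_counts(23))   # roughly 0.23 of the way to black: light gray
gray_11 = greys(by_counts(106))  # above vmax, so drawn with the darkest color

# Coloring by column fractions instead, on a 0..1 scale.
by_fraction = Normalize(vmin=0, vmax=1)
dark_00 = greys(by_fraction(23 / 23))    # 100% -> black
dark_11 = greys(by_fraction(106 / 118))  # about 89.8% -> slightly lighter than black
```

With fractions as the data and vmin=0, vmax=1, the 100% cell becomes the darkest one, which is the ordering the question asks for.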

Upvotes: 2

Views: 655

Answers (1)

Zephyr

Reputation: 12496

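sns.heatmap colors each cell according to the array passed as its data argument; annot only sets the text drawn on top. Because you pass cf, the colors follow the counts. Use group_percentages as the data instead, but first reshape the list into a 3x3 matrix. The list was filled column by column, so reshape it in Fortran order:

percentages_matrix = np.reshape(group_percentages, (3, 3), order='F')

The percentages are fractions between 0 and 1, so set vmin=0 and vmax=1 instead of vmax=100.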
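A quick NumPy check of how the reshape order matters, assuming the example cf used in the complete code below. The loop appends values column by column, so np.reshape with its default order='C' produces the transpose of the intended matrix, while order='F' (or a single vectorized division of each column by its sum) gives the intended layout. Since the diagonal is the same either way, a mix-up is easy to miss on a mostly diagonal confusion matrix.

```python
import numpy as np

cf = np.array([[23, 0, 3],
               [0, 106, 5],
               [0, 12, 76]])

# Same loop as above: the outer loop walks columns, the inner loop rows,
# so the flat list is filled column by column.
flat = []
for i in range(len(cf)):
    for j in range(len(cf)):
        flat.append(cf[j, i] / np.sum(cf[:, i]))

as_rows = np.reshape(flat, (3, 3))             # default order='C': transposed
as_cols = np.reshape(flat, (3, 3), order="F")  # matches cf's layout

# Vectorized equivalent: divide every column by its own total.
pct = cf / cf.sum(axis=0, keepdims=True)

print(np.allclose(as_cols, pct))    # True
print(np.allclose(as_rows, pct.T))  # True
```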

Complete Code

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

cf = np.array([[23, 0, 3],
               [0, 106, 5],
               [0, 12, 76]])

categories = ['a', 'b', 'c']
group_percentages = []
counts = []
for i in range(len(cf)):
    for j in range(len(cf)):
        group_percentages.append(cf[j, i]/np.sum(cf[:, i]))
        counts.append(cf[j, i])

# the list was filled column by column, so rebuild the matrix in Fortran order
percentages_matrix = np.reshape(group_percentages, (3, 3), order='F')
group_percentages = ['{0:.2%}'.format(value) for value in group_percentages]

labels = [f'{v1}\n{v2}' for v1, v2 in zip(group_percentages, counts)]
labels = np.asarray(labels).reshape(3, 3, order='F')

# color by the percentages (fractions in [0, 1]); annotate with "percent\ncount"
sns.heatmap(percentages_matrix, annot=labels, fmt='',
            xticklabels=categories, yticklabels=categories,
            cmap='Greys', vmin=0, vmax=1, cbar=False)

plt.show()
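A more compact variant of the same approach, offered as a sketch rather than a drop-in replacement: it computes the percentages with one column-wise division and builds the labels cell by cell, so no reshape order is involved. The Agg backend is an addition so the sketch runs without a display, and treating ax.collections[0] as the heatmap's color mesh is an assumption about seaborn's internals that holds when the axes is otherwise empty.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

cf = np.array([[23, 0, 3],
               [0, 106, 5],
               [0, 12, 76]])
categories = ["a", "b", "c"]

# Column-normalize: each count divided by the total of its column.
pct = cf / cf.sum(axis=0, keepdims=True)

# Build "percent\ncount" labels cell by cell, in the same layout as cf.
labels = np.array([[f"{p:.2%}\n{c}" for p, c in zip(prow, crow)]
                   for prow, crow in zip(pct, cf)])

ax = sns.heatmap(pct, annot=labels, fmt="", cmap="Greys",
                 vmin=0, vmax=1, cbar=False,
                 xticklabels=categories, yticklabels=categories)
mesh = ax.collections[0]  # the mesh whose values drive the cell colors
```

Checking the mesh values against pct is a simple way to confirm that the colors now come from the percentages rather than the counts.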

[Image: the resulting heatmap]

Upvotes: 1
