kamilazdybal

Reputation: 323

Highlighting maximum value in a column on a seaborn heatmap

I have a seaborn.heatmap plotted from a DataFrame:

[Image: the resulting heatmap]

import seaborn as sns
import matplotlib.pyplot as plt

fig = plt.figure(facecolor='w', edgecolor='k')
sns.heatmap(collected_data_frame, annot=True, vmax=1.0, cmap='Blues', cbar=False, fmt='.4g')

I would like to highlight the maximum value in each column: for example, with a red box around that value, a red dot plotted next to it, or by coloring that cell red instead of using the Blues colormap. Ideally, I'm expecting something like this:

[Image: heatmap with the maximum of each column highlighted]

I got the highlighting working when printing the DataFrame in Jupyter Notebook, using tips from this answer:

[Image: DataFrame printed in Jupyter Notebook with the column maxima highlighted]

How can I achieve a similar thing but on a heatmap?

Upvotes: 2

Views: 5966

Answers (2)

kamilazdybal

Reputation: 323

Complete solution based on the answer of @r-beginners:

Generate DataFrame:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import seaborn

arr = np.array([[0.9336719 , 0.90119269, 0.90791181, 0.3112451 , 0.56715989,
        0.83339874, 0.14571595, 0.36505745, 0.89847367, 0.95317909,
        0.16396293, 0.63463356],
       [0.93282304, 0.90605976, 0.91276066, 0.30288519, 0.56366228,
        0.83032344, 0.14633036, 0.36081791, 0.9041638 , 0.95268572,
        0.16803188, 0.63459491],
       [0.15215358, 0.4311569 , 0.32324376, 0.51620611, 0.69872915,
        0.08811177, 0.80087247, 0.234593  , 0.47973905, 0.21688613,
        0.2738223 , 0.38322856],
       [0.90406056, 0.89632902, 0.92220635, 0.3022458 , 0.58843012,
        0.78159595, 0.17089609, 0.33443782, 0.89997103, 0.93128579,
        0.15942313, 0.62644379],
       [0.93868063, 0.45617598, 0.17708323, 0.81828266, 0.72986428,
        0.82543775, 0.41530088, 0.2604382 , 0.33132295, 0.94686745,
        0.05607774, 0.54141198]])

columns_text = [str(num) for num in range(0,12)]
index_text = ['C1', 'C2', 'C3', 'C4', 'C5']
arr_data_frame = pd.DataFrame(arr, columns=columns_text, index=index_text)

Highlighting the maximum in each column:

fig, ax = plt.subplots(figsize=(15, 3), facecolor='w', edgecolor='k')
ax = seaborn.heatmap(arr_data_frame, annot=True, vmax=1.0, vmin=0, cmap='Blues', cbar=False, fmt='.4g', ax=ax)

# Row label of the maximum in each column
column_max = arr_data_frame.idxmax(axis=0)

# On the heatmap axes, cell (row i, column j) spans x in [j, j+1] and y in [i, i+1]
for col, variable in enumerate(columns_text):
    position = arr_data_frame.index.get_loc(column_max[variable])
    ax.add_patch(Rectangle((col, position), 1, 1, fill=False, edgecolor='red', lw=3))

plt.savefig('max_column_heatmap.png', dpi=500, bbox_inches='tight')
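The column and row loops differ only in which coordinate is fixed, so they can be folded into one helper. This is a sketch, not part of the original answer: `highlight_max` is a hypothetical name, and it assumes the heatmap was drawn from the same DataFrame with seaborn's default cell layout. It uses `np.nanargmax`, which, like `idxmax`, skips NaN and returns the first maximum (it raises `ValueError` on an all-NaN column or row).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import Rectangle


def highlight_max(ax, df, axis=0, **rect_kwargs):
    """Hypothetical helper: box the maximum of each column (axis=0) or row (axis=1).

    Assumes cell (row i, column j) of the heatmap on `ax` spans
    x in [j, j+1] and y in [i, i+1], as in seaborn.heatmap.
    """
    style = {"fill": False, "edgecolor": "red", "lw": 3}
    style.update(rect_kwargs)  # caller may override the box style
    values = df.to_numpy(dtype=float)
    # Positions (not labels) of the first maximum along the chosen axis
    argmax = np.nanargmax(values, axis=axis)
    patches = []
    for k, pos in enumerate(argmax):
        # axis=0: k is the column and pos the row; axis=1: k is the row and pos the column
        x, y = (k, pos) if axis == 0 else (pos, k)
        patches.append(ax.add_patch(Rectangle((x, y), 1, 1, **style)))
    return patches
```

With the answer's data, calling `highlight_max(ax, arr_data_frame, axis=0)` right after `seaborn.heatmap(...)` should reproduce the column boxes, and `axis=1` the row boxes.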

[Image: max_column_heatmap.png]

Highlighting the maximum in each row:

fig, ax = plt.subplots(figsize=(15, 3), facecolor='w', edgecolor='k')
ax = seaborn.heatmap(arr_data_frame, annot=True, vmax=1.0, vmin=0, cmap='Blues', cbar=False, fmt='.4g', ax=ax)

# Column label of the maximum in each row
row_max = arr_data_frame.idxmax(axis=1)

# Here the row is fixed and the column position varies
for row, index in enumerate(index_text):
    position = arr_data_frame.columns.get_loc(row_max[index])
    ax.add_patch(Rectangle((position, row), 1, 1, fill=False, edgecolor='red', lw=3))

plt.savefig('max_row_heatmap.png', dpi=500, bbox_inches='tight')
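The question also mentions two other styles: coloring the cell red, or a red dot next to the value. A sketch of both, under the same cell-layout assumption as above; `max_cells` and `mark_max` are hypothetical names. Unlike `idxmax`, which reports only the first maximum, `max_cells` returns every cell tied for the maximum. The filled rectangle should sit under the annotations because matplotlib's default zorder for patches is lower than for text.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import Rectangle


def max_cells(df, axis=0):
    """Hypothetical helper: (row, col) positions of every cell equal to the
    maximum of its column (axis=0) or row (axis=1); ties are all returned."""
    values = df.to_numpy(dtype=float)
    # keepdims=True lets the maxima broadcast back against the full array
    maxima = np.nanmax(values, axis=axis, keepdims=True)
    return [tuple(int(v) for v in rc) for rc in np.argwhere(values == maxima)]


def mark_max(ax, df, axis=0, style="fill", color="red"):
    """Hypothetical helper: fill maximal cells, or put a dot beside their value."""
    for row, col in max_cells(df, axis=axis):
        if style == "fill":
            # Opaque patch over the colormap cell; text annotations stay on top
            ax.add_patch(Rectangle((col, row), 1, 1, facecolor=color, edgecolor="none"))
        elif style == "dot":
            # Dot near the right edge of the cell, beside the centered value
            ax.plot(col + 0.85, row + 0.5, marker="o", color=color, markersize=6)
        else:
            raise ValueError("style must be 'fill' or 'dot'")
```

For example, `mark_max(ax, arr_data_frame, axis=0, style="dot")` after the `seaborn.heatmap` call would place a dot in each column's maximal cell.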

[Image: max_row_heatmap.png]

Upvotes: 3

r-beginners

Reputation: 35155

I customized the annotated heatmap example from the official seaborn documentation, using techniques from answers on this site. The idea is to add a patch on top of the existing plot. Here I placed a frame around the maximum value by hand, with hard-coded coordinates.

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import seaborn as sns
sns.set()

# Load the example flights dataset and pivot it from long to wide form
flights_long = sns.load_dataset("flights")
flights = flights_long.pivot(index="month", columns="year", values="passengers")

# Draw a heatmap with the numeric values in each cell
f, ax = plt.subplots(figsize=(9, 6))
ax = sns.heatmap(flights, annot=True, fmt="d", linewidths=.5, ax=ax)

# Hand-placed 2x2 frame: columns 10-11 (1959-1960), rows 6-7 (seventh and eighth months)
ax.add_patch(Rectangle((10, 6), 2, 2, fill=False, edgecolor='blue', lw=3))

[Image: flights heatmap with the hand-placed blue frame]

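To place the frame programmatically, look up the position of the maximum:

# max(flights) iterates over the column labels, so it returns the last year (1960),
# not the largest value; this works here only because 1960 holds the overall maximum
ymax = max(flights)                   # 1960
ypos = flights.columns.get_loc(ymax)  # 11
xmax = flights[ymax].idxmax()         # month with the most passengers in 1960
xpos = flights.index.get_loc(xmax)    # 6

# Rectangle takes (x, y) = (column position, row position), not the year label
ax.add_patch(Rectangle((ypos, xpos), 1, 1, fill=False, edgecolor='blue', lw=3))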
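`max(flights)` returns the largest column label (the last year) rather than the largest value, so that lookup finds the overall maximum only because it happens to fall in 1960. A sketch that locates the true global maximum of any DataFrame, assuming the same heatmap layout; `global_max_position` is a hypothetical name. It uses `np.nanargmax` on the flattened values and `np.unravel_index` to map that flat index back to (row, column).

```python
import numpy as np
import pandas as pd


def global_max_position(df):
    """Hypothetical helper: (x, y) heatmap coordinates of the largest value in df.

    NaN cells are ignored; x is the column position and y the row position.
    """
    values = df.to_numpy(dtype=float)
    # nanargmax indexes the flattened array; unravel_index maps it back to (row, col)
    row, col = np.unravel_index(np.nanargmax(values), values.shape)
    return int(col), int(row)
```

Usage with the flights heatmap: `ax.add_patch(Rectangle(global_max_position(flights), 1, 1, fill=False, edgecolor='blue', lw=3))`.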

Upvotes: 2
