x89

Reputation: 3490

Create a heatmap with only 2 colors

I have a dataset that looks like this:

Profession     Australia_F   Australia_M      Canada_F      Canada_M    Kenya_F   Kenya_M
Author         20            80               55            34          60        23
Librarian      10            34               89            33          89        12
Pilot          78            12               67            90          12        55

I want to plot a sort of heatmap with these values. I tried this:

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# df holds the table above, with Profession as a regular column
melted_df = pd.melt(df, id_vars='Profession', var_name='Country_Gender', value_name='Number')
melted_df[['Country', 'Gender']] = melted_df['Country_Gender'].str.split('_', expand=True)
melted_df['Number'] = pd.to_numeric(melted_df['Number'], errors='coerce')

heatmap_data = melted_df.pivot_table(index='Profession', columns=['Country', 'Gender'], values='Number')

plt.figure(figsize=(10, 8))
sns.heatmap(heatmap_data, cmap='coolwarm', annot=True, fmt=".1f", linewidths=.5)
plt.xlabel('Country and Gender')
plt.ylabel('Profession')
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig('heatmap.png')

It seems to work, but it currently assigns a different color to every cell based on its numerical value. I only want 2 colors in my chart: red and blue.

Specifically, for each profession (each row), I want to compare each country's F and M values and color the cell with the higher value red.

For example, for Author, these three cells should be red:

Australia_M (80), Canada_F (55), Kenya_F (60)

The other three cells in that row should be blue. How can I achieve this?
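To pin the rule down, here is a small sketch in plain Python. `red_columns` is a hypothetical helper (not part of pandas or seaborn) that takes one row as a dict and returns the columns whose value is strictly greater than the same country's other-gender value. It assumes the column names follow the `<Country>_<F|M>` pattern from the table, and on a tie it returns neither cell of that pair.

```python
def red_columns(row):
    """Hypothetical helper: names of the cells that beat their F/M partner.

    `row` maps 'Country_Gender' names to numbers, e.g. {'Australia_F': 20, ...}.
    A tie between F and M marks neither cell.
    """
    red = []
    for col, value in row.items():
        # Split on the last underscore: 'Australia_F' -> ('Australia', 'F')
        country, gender = col.rsplit('_', 1)
        other = f"{country}_{'M' if gender == 'F' else 'F'}"
        if value > row[other]:
            red.append(col)
    return red


author = {'Australia_F': 20, 'Australia_M': 80, 'Canada_F': 55,
          'Canada_M': 34, 'Kenya_F': 60, 'Kenya_M': 23}
print(red_columns(author))  # ['Australia_M', 'Canada_F', 'Kenya_F']
```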

Upvotes: 1

Views: 79

Answers (2)

JohanC

Reputation: 80534

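You can use two different dataframes: one for the coloring and one for the text annotations. Make a copy of the original dataframe and compare the columns at even positions (the _F values) with their neighbors at odd positions (the _M values); this gives a dataframe of booleans. These booleans (internally 0 for False and 1 for True) then decide the coloring: with cmap='coolwarm', False cells get the blue end of the colormap and True cells the red end.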
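A minimal sketch of the same comparison using NumPy slicing instead of a column loop, assuming the columns strictly alternate _F, _M as in the question (the names `mask` and `df_coloring` are illustrative):

```python
import numpy as np
import pandas as pd

data = {'Profession': ['Author', 'Librarian', 'Pilot'],
        'Australia_F': [20, 10, 78],
        'Australia_M': [80, 34, 12],
        'Canada_F': [55, 89, 67],
        'Canada_M': [34, 33, 90],
        'Kenya_F': [60, 89, 12],
        'Kenya_M': [23, 12, 55]}
df = pd.DataFrame(data).set_index('Profession')

values = df.to_numpy()
# Positions 0, 2, 4 hold the _F values; positions 1, 3, 5 hold the _M values
f_vals, m_vals = values[:, ::2], values[:, 1::2]

# Write both comparisons back into the original F, M, F, M, ... column order
mask = np.empty(values.shape, dtype=bool)
mask[:, ::2] = f_vals > m_vals
mask[:, 1::2] = m_vals > f_vals
df_coloring = pd.DataFrame(mask, index=df.index, columns=df.columns)

# Viewed as numbers, False/True are 0/1, the low and high ends of the colormap
print(df_coloring.astype(int).loc['Author'].tolist())  # [0, 1, 1, 0, 1, 0]
```

In the Author row only Australia_M, Canada_F and Kenya_F come out as 1, which are the cells the question wants red.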

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

data = {'Profession': ['Author', 'Librarian', 'Pilot'],
        'Australia_F': [20, 10, 78],
        'Australia_M': [80, 34, 12],
        'Canada_F': [55, 89, 67],
        'Canada_M': [34, 33, 90],
        'Kenya_F': [60, 89, 12],
        'Kenya_M': [23, 12, 55]}
df = pd.DataFrame(data).set_index('Profession')
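# Copy the frame, then overwrite each column with "does this cell beat its partner?"
df_coloring = df.copy()
# Even-position columns are the _F columns, odd-position columns the _M columns
for colF, colM in zip(df_coloring.columns[::2], df_coloring.columns[1::2]):
    df_coloring[colF] = df[colF] > df[colM]
    df_coloring[colM] = df[colM] > df[colF]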

sns.set_style('white')
plt.figure(figsize=(10, 8))
sns.heatmap(df_coloring, cmap='coolwarm', annot=df, fmt=".1f", linewidths=.5, cbar=False)
plt.xlabel('Country and Gender')
plt.ylabel('Profession')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

[Image: sns.heatmap with two colors]
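The ends of 'coolwarm' are shades of blue and red rather than plain 'blue' and 'red'. If you want exactly those two colors, one option (a sketch, not part of the answer above) is a two-entry `ListedColormap` with `vmin=0` and `vmax=1`, so that 0 (False) takes the first color and 1 (True) the second. The annotations use `fmt='d'` because the original values are integers, and the output file name is arbitrary:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.colors import ListedColormap

data = {'Profession': ['Author', 'Librarian', 'Pilot'],
        'Australia_F': [20, 10, 78],
        'Australia_M': [80, 34, 12],
        'Canada_F': [55, 89, 67],
        'Canada_M': [34, 33, 90],
        'Kenya_F': [60, 89, 12],
        'Kenya_M': [23, 12, 55]}
df = pd.DataFrame(data).set_index('Profession')

# Same comparison as in the answer: True where a cell beats its F/M partner
df_coloring = df.copy()
for colF, colM in zip(df.columns[::2], df.columns[1::2]):
    df_coloring[colF] = df[colF] > df[colM]
    df_coloring[colM] = df[colM] > df[colF]

# Exactly two colors: 0 (False) -> 'blue', 1 (True) -> 'red'
two_colors = ListedColormap(['blue', 'red'])

fig, ax = plt.subplots(figsize=(10, 8))
# Color by the 0/1 frame, annotate with the original numbers
sns.heatmap(df_coloring.astype(int), cmap=two_colors, vmin=0, vmax=1,
            annot=df, fmt='d', linewidths=.5, cbar=False, ax=ax)
ax.set_xlabel('Country and Gender')
ax.set_ylabel('Profession')
fig.tight_layout()
fig.savefig('heatmap_two_colors.png')
```

Fixing `vmin` and `vmax` also keeps the mapping stable if a dataframe ever contains only True or only False values.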

Optionally, you could add extra separators and put the gender labels at the top and the country labels at the bottom:

sns.set_style('white')
plt.figure(figsize=(10, 8))
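ax = sns.heatmap(df_coloring, cmap='coolwarm', annot=df, fmt=".0f", linewidths=.5, cbar=False, annot_kws={"size": 22})
# Country names: every second tick label with its "_F" suffix stripped
countries = [l.get_text()[:-2] for l in ax.get_xticklabels()[::2]]
# Secondary x-axis at the top showing just the gender letter of each column
ax_top = ax.secondary_xaxis('top')
ax_top.set_xticks(ax.get_xticks(), [l.get_text()[-1:] for l in ax.get_xticklabels()])
ax_top.tick_params(length=0)
# Bottom axis: one country label on the shared edge of each F/M pair
ax.set_xticks(range(1, len(df.columns), 2), countries)

# Thick white lines between the country pairs and between the rows
for i in range(0, len(df.columns) + 1, 2):
    ax.axvline(i, lw=4, color='white')
for i in range(0, len(df) + 1):
    ax.axhline(i, lw=4, color='white')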
ax.set_xlabel('Country and Gender')
ax.set_ylabel('Profession')
plt.tight_layout()
plt.show()

[Image: sns.heatmap with extra separations]
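The label code relies on fixed positions: seaborn draws cell i between x = i and x = i + 1, so x = 1, 3, 5 is the shared edge of each F/M pair, and `l.get_text()[:-2]` assumes a two-character suffix such as `_F`. A pure-Python sketch of that arithmetic, using a hypothetical helper `split_tick_labels` that splits on the last underscore instead of slicing:

```python
def split_tick_labels(columns):
    """Hypothetical helper: derive the tick labels and positions used above.

    Assumes the columns alternate <Country>_F, <Country>_M.
    """
    # Split on the last underscore only, so 'New_Zealand_F' -> ['New_Zealand', 'F']
    parts = [col.rsplit('_', 1) for col in columns]
    genders = [gender for _, gender in parts]           # one per cell (top axis)
    countries = [country for country, _ in parts[::2]]  # one per F/M pair (bottom axis)
    # Cell i spans x = i .. i + 1, so x = 1, 3, 5, ... is the shared edge of each pair
    country_ticks = list(range(1, len(columns), 2))
    # Thick separators on the outer edges of every pair: x = 0, 2, 4, ...
    separators = list(range(0, len(columns) + 1, 2))
    return countries, country_ticks, genders, separators


columns = ['Australia_F', 'Australia_M', 'Canada_F', 'Canada_M', 'Kenya_F', 'Kenya_M']
countries, country_ticks, genders, separators = split_tick_labels(columns)
print(countries)      # ['Australia', 'Canada', 'Kenya']
print(country_ticks)  # [1, 3, 5]
print(separators)     # [0, 2, 4, 6]
```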

Upvotes: 1

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

data = {
    'Profession': ['Author', 'Librarian', 'Pilot'],
    'Australia_F': [20, 10, 78],
    'Australia_M': [80, 34, 12],
    'Canada_F': [55, 89, 67],
    'Canada_M': [34, 33, 90],
    'Kenya_F': [60, 89, 12],
    'Kenya_M': [23, 12, 55]
}
df = pd.DataFrame(data)
melted_df = pd.melt(df, id_vars='Profession', var_name='Country_Gender', value_name='Number')
melted_df[['Country', 'Gender']] = melted_df['Country_Gender'].str.split('_', expand=True)
melted_df['Number'] = pd.to_numeric(melted_df['Number'], errors='coerce')

def assign_color(row):
    # Look up the other gender's value for the same profession and country
    other_gender = 'M' if row['Gender'] == 'F' else 'F'
    other_value = melted_df[(melted_df['Profession'] == row['Profession'])
                            & (melted_df['Country'] == row['Country'])
                            & (melted_df['Gender'] == other_gender)]['Number'].values[0]
    return 'red' if row['Number'] > other_value else 'blue'

melted_df['Color'] = melted_df.apply(assign_color, axis=1)
# Encode the colors as 0 (blue) / 1 (red) so they, not the raw numbers, drive the heatmap
melted_df['IsRed'] = (melted_df['Color'] == 'red').astype(int)

heatmap_data = melted_df.pivot_table(index='Profession', columns=['Country', 'Gender'], values='Number')
color_data = melted_df.pivot_table(index='Profession', columns=['Country', 'Gender'], values='IsRed')

plt.figure(figsize=(10, 8))
# Color by color_data, but print the original numbers from heatmap_data in each cell
sns.heatmap(color_data, cmap=ListedColormap(['blue', 'red']), vmin=0, vmax=1,
            annot=heatmap_data, fmt=".1f", linewidths=.5, cbar=False, square=True,
            annot_kws={"fontsize": 10})
plt.xlabel('Country and Gender')
plt.ylabel('Profession')
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig('heatmap.png')
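`assign_color` filters the whole of `melted_df` once per row, so the work grows quadratically with the number of rows. A vectorized sketch of the same red/blue rule, assuming every (Profession, Country) pair has exactly one F row and one M row, uses `groupby(...).transform('sum')`: the pair's sum minus a row's own value is then the other gender's value.

```python
import numpy as np
import pandas as pd

data = {'Profession': ['Author', 'Librarian', 'Pilot'],
        'Australia_F': [20, 10, 78],
        'Australia_M': [80, 34, 12],
        'Canada_F': [55, 89, 67],
        'Canada_M': [34, 33, 90],
        'Kenya_F': [60, 89, 12],
        'Kenya_M': [23, 12, 55]}
melted_df = pd.melt(pd.DataFrame(data), id_vars='Profession',
                    var_name='Country_Gender', value_name='Number')
melted_df[['Country', 'Gender']] = melted_df['Country_Gender'].str.split('_', expand=True)

# Assumes exactly two rows (F and M) per (Profession, Country):
# the group sum minus this row's value is the partner's value.
pair_sum = melted_df.groupby(['Profession', 'Country'])['Number'].transform('sum')
other = pair_sum - melted_df['Number']
# Strictly greater -> red; ties leave both cells blue
melted_df['Color'] = np.where(melted_df['Number'] > other, 'red', 'blue')
```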

Upvotes: 0
