SH_Clarity

Reputation: 73

How to add legends to boxplots with seaborn

I have a set of data for which I need to make boxplots. I am having trouble adding a legend whose colors match each individual box. Here is my code:

import matplotlib.pyplot as plt
import seaborn as sns

model_names=["a","b","c","d","e","f","g","h"]
fig = plt.figure()
legend=model_names
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)
sns.boxplot(data=df_A, ax=ax1, labels=model_names) # Adding model_names here gives errors.
ax1.set_xticklabels([])
ax1.tick_params(axis = "x", which = "both", bottom = False, top = False)

ax1.set_xlabel("A", fontsize=16)
ax1.set_ylabel("Score", fontsize=16)

sns.boxplot(data=df_B, ax=ax2)
ax2.set_xticklabels([])
ax2.tick_params(axis = "x", which = "both", bottom = False, top = False)
ax2.set_xlabel("B", fontsize=16)

Here is my dataframe, df_A:

[Image: screenshot of the df_A dataframe]

and here is the plot, without legends:

[Image: the resulting boxplots, without legends]

Upvotes: 1

Views: 3720

Answers (1)

JohanC

Reputation: 80289

Seaborn prefers one combined dataframe in "long form". With such a dataframe, a catplot(kind='box', ...) can be created using the old column names for both x and hue. The hue names will end up in the legend.

Here is some example code, showing how to create the "long form" and combine the dataframes.

from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

model_names = ["a", "b", "c", "d", "e", "f", "g", "h"]

df_A = pd.DataFrame(np.random.rand(15, 8), columns=model_names)
df_B = pd.DataFrame(np.random.rand(20, 8), columns=model_names)
df_A_long = df_A.melt()
df_A_long['source'] = 'A'
df_B_long = df_B.melt()
df_B_long['source'] = 'B'
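# DataFrame.append() was removed in pandas 2.0; pd.concat() is the replacement
df_combined = pd.concat([df_A_long, df_B_long], ignore_index=True)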

g = sns.catplot(kind='box', data=df_combined, col='source', x='variable', y='value', hue='variable',
                dodge=False, palette=sns.color_palette("Set2"), legend_out=True)
g.add_legend()
plt.setp(g.axes, xticks=[], xlabel='') # remove x ticks and xlabel
g.fig.subplots_adjust(left=0.06) # more space for the y-label
plt.show()

[Image: box plots of multiple dataframes]

df_A_long is created via melt(), which puts the original column names in a new column called "variable" and the corresponding values in a column called "value". After adding the "source" column it looks like:

    variable     value source
0          a  0.581008      A
1          a  0.037324      A
2          a  0.833181      A
....
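
As a minimal sketch of the reshaping step, the snippet below melts a small hand-made dataframe and tags it with its source. The var_name and value_name arguments are optional melt() parameters that the code above does not use; they are shown here only as an assumption about how one might give the columns more descriptive names than the defaults "variable" and "value". The data values are made up for illustration.

```python
import pandas as pd

# Small illustrative dataframe (values are made up for this example)
df_A = pd.DataFrame({"a": [0.1, 0.2], "b": [0.3, 0.4]})

# melt() stacks the columns: one row per (column name, value) pair.
# var_name / value_name only rename the two resulting columns.
df_A_long = df_A.melt(var_name="model", value_name="score")
df_A_long["source"] = "A"

print(df_A_long)
```

If the columns are renamed this way, the catplot call would need x='model', y='score' and hue='model' instead of 'variable' and 'value'.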

Note that, in general, labeling each box directly on the x-axis is clearer than coloring the boxes and referring to a legend: the extra indirection makes it harder for the viewer to read the plot quickly. But if, for example, the names are too long to fit nicely on the x-axis, or there are many similar subplots, a legend can be a great solution.
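
For comparison, here is a rough sketch of the direct-labeling variant mentioned above: the same long-form dataframe, but without hue and without removing the x ticks, so each box is named on the x-axis. The helper name plot_with_tick_labels is hypothetical, and the random data only stands in for the real dataframes.

```python
import numpy as np
import pandas as pd
import seaborn as sns

model_names = ["a", "b", "c", "d", "e", "f", "g", "h"]
rng = np.random.default_rng(0)  # seeded only to make the sketch reproducible

# Placeholder data standing in for the real df_A and df_B
df_A = pd.DataFrame(rng.random((15, 8)), columns=model_names)
df_B = pd.DataFrame(rng.random((20, 8)), columns=model_names)

# Long form with a "source" column, as in the answer's example
df_combined = pd.concat(
    [df_A.melt().assign(source="A"), df_B.melt().assign(source="B")],
    ignore_index=True,
)


def plot_with_tick_labels(data):
    """Hypothetical helper: one subplot per source, boxes named on the x-axis."""
    g = sns.catplot(kind="box", data=data, col="source", x="variable", y="value")
    g.set_axis_labels("", "Score")  # the tick labels already name each box
    return g


g = plot_with_tick_labels(df_combined)
```

Calling plt.show() (from matplotlib.pyplot) afterwards displays the figure; no legend is needed because the tick labels carry the names.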

Upvotes: 2
