everestial

Reputation: 7255

Create a single legend for multiple seaborn plots

I am making box plots using the "iris.csv" data. I am trying to split the data into multiple dataframes by measurement (i.e. petal-length, petal-width, sepal-length, sepal-width) and then draw a box plot for each one in a for loop, adding each as a subplot.

Finally, I want to add one common legend for all the box plots, but I am not able to do it. I have tried several tutorials and approaches from several Stack Overflow questions, but I have not been able to fix it.

Here is my code:

import pandas as pd
import seaborn as sns
from matplotlib import pyplot
from matplotlib import patches as mpatches

iris_data = "iris.csv"
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
dataset = pd.read_csv(iris_data, names=names)


# Reindex the dataset by species so it can be pivoted for each species 
reindexed_dataset = dataset.set_index(dataset.groupby('class').cumcount())
cols_to_pivot = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width']

# Collect one pivoted frame per measurement, then concatenate them
# (DataFrame.append was removed in pandas 2.0)
pivoted_frames = []
for var_name in cols_to_pivot:
    pivoted_dataset = reindexed_dataset.pivot(columns='class', values=var_name).rename_axis(None, axis=1)
    pivoted_dataset['measurement'] = var_name
    pivoted_frames.append(pivoted_dataset)
reshaped_dataset = pd.concat(pivoted_frames, ignore_index=True)


## Now, let's split the dataframe into groups by measurement.
grouped_dfs_02 = []
for group in reshaped_dataset.groupby('measurement') :
    grouped_dfs_02.append(group[1])


## make the box plot of several measured variables, compared between species 

pyplot.figure(figsize=(20, 5), dpi=80)
pyplot.suptitle('Distribution of floral traits in the species of iris')

sp_name=['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
setosa = mpatches.Patch(color='red')
versi = mpatches.Patch(color='green')
virgi = mpatches.Patch(color='blue')

my_pal = {"Iris-versicolor": "g", "Iris-setosa": "r", "Iris-virginica":"b"}
plt_index = 0


# for i, df in enumerate(grouped_dfs_02):
for group_name, df in reshaped_dataset.groupby('measurement'):

    axi = pyplot.subplot(1, len(grouped_dfs_02), plt_index + 1)
    sp_name=['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    df_melt = df.melt('measurement', var_name='species', value_name='values')

    sns.boxplot(data=df_melt, x='species', y='values', ax = axi, orient="v", palette=my_pal)
    pyplot.title(group_name)
    plt_index += 1


# Move the legend to an empty part of the plot
pyplot.legend(title='species', labels = sp_name, 
         handles=[setosa, versi, virgi], bbox_to_anchor=(19, 4),
           fancybox=True, shadow=True, ncol=5)


pyplot.show()

Here is the plot: [image]

How do I add a common legend to the main figure, outside the main frame, beside the main suptitle?

Upvotes: 1

Views: 9975

Answers (2)

Trenton McKinney

Reputation: 62523

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# load iris data
iris = sns.load_dataset("iris")

   sepal_length  sepal_width  petal_length  petal_width species
0           5.1          3.5           1.4          0.2  setosa
1           4.9          3.0           1.4          0.2  setosa
2           4.7          3.2           1.3          0.2  setosa
3           4.6          3.1           1.5          0.2  setosa
4           5.0          3.6           1.4          0.2  setosa

# create figure
fig, axes = plt.subplots(ncols=4, figsize=(20, 5), sharey=True)

# add subplots
for ax, col in zip(axes, iris.columns[:-1]):
    sns.boxplot(x='species', y=col, data=iris, hue='species', dodge=False, ax=ax)
    ax.get_legend().remove()
    ax.set_title(col)

# add legend
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper right', ncol=3, bbox_to_anchor=(0.8, 1), frameon=False)

# add figure title (suptitle)
fig.suptitle('Distribution of floral traits in the species of iris')

plt.show()



  • However, a legend is not strictly required here. It repeats information the plot already shows, because the colors are the same across all subplots and each box is already labeled on the x-axis.
  • A more succinct option is to convert the dataframe to long format with pandas.DataFrame.melt, and then plot with sns.catplot and kind='box'.
dfm = iris.melt(id_vars='species', var_name='parameter', value_name='measurement', ignore_index=True)

  species     parameter  measurement
0  setosa  sepal_length          5.1
1  setosa  sepal_length          4.9
2  setosa  sepal_length          4.7
3  setosa  sepal_length          4.6
4  setosa  sepal_length          5.0

g = sns.catplot(kind='box', data=dfm, x='species', y='measurement', hue='species', col='parameter', dodge=False)
_ = g.fig.suptitle('Distribution of floral traits in the species of iris', y=1.1)


  • Optionally, plot all the values in a single subplot, which makes comparing the 'parameter' across 'species' easier.
g = sns.catplot(kind='box', data=dfm, x='parameter', y='measurement', hue='species', height=4, aspect=2)

_ = g.fig.suptitle('Distribution of floral traits in the species of iris', y=1.1)


Upvotes: 9

JohanC

Reputation: 80509

To position the legend, it is important to set the loc parameter, which determines the anchor point of the legend box. (The default loc is 'best', which means you don't know beforehand where the legend will end up.) The positions passed to bbox_to_anchor are measured from 0,0 at the lower left of the current ax to 1,1 at the upper right of the current ax. This doesn't include the padding for titles etc., so the values can go a bit outside the 0 to 1 range. The "current ax" is the last one that was activated.
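As a rough illustration of how loc and bbox_to_anchor interact, here is a minimal sketch. The helper names species_handles and demo_loc_anchor are made up for this example, and the patch colors simply mirror the palette used in the question. Both panels anchor the legend at the same point (the center of the axes), but a different corner of the legend box is pinned there.

```python
import matplotlib.pyplot as plt
from matplotlib import patches as mpatches


def species_handles():
    """Build one colored patch per species (hypothetical helper)."""
    colors = {"setosa": "red", "versicolor": "green", "virginica": "blue"}
    return [mpatches.Patch(color=color, label=name) for name, color in colors.items()]


def demo_loc_anchor():
    """Draw the same legend twice, anchored at the same point with a different loc."""
    fig, axes = plt.subplots(ncols=2, figsize=(8, 3))
    for ax, loc in zip(axes, ["upper right", "lower left"]):
        # Mark the anchor point (0.5, 0.5), given in axes coordinates.
        ax.plot([0.5], [0.5], marker="+", color="black", transform=ax.transAxes)
        # loc names the corner of the legend box that is pinned to bbox_to_anchor.
        ax.legend(handles=species_handles(), loc=loc, bbox_to_anchor=(0.5, 0.5))
        ax.set_title(f"loc={loc!r}")
    return fig
```

Calling demo_loc_anchor() followed by plt.show() should show the legend below and to the left of the marker in the first panel, and above and to the right of it in the second.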

Note that instead of plt.legend (which uses the current axes), you could also use plt.gcf().legend, which uses the "figure". Then the coordinates run from 0,0 at the lower left corner of the complete plot (the "figure") to 1,1 at the upper right. One drawback is that no extra space is created for the legend, so you'd need to set a top padding manually (e.g. plt.gcf().subplots_adjust(top=0.8)). Another drawback is that you can't use plt.tight_layout() anymore, and that it is harder to align the legend with the axes.
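The figure-level variant could look roughly like the sketch below. It is an illustration only: it uses random stand-in data and plain matplotlib box plots instead of the iris data and seaborn, and the function name boxplots_with_figure_legend is invented for this example.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches as mpatches


def boxplots_with_figure_legend():
    """Four box-plot panels sharing one figure-level legend (illustrative data)."""
    rng = np.random.default_rng(0)
    colors = {"setosa": "red", "versicolor": "green", "virginica": "blue"}
    measurements = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

    fig, axes = plt.subplots(ncols=4, figsize=(20, 5), sharey=True)
    for ax, name in zip(axes, measurements):
        # Random stand-in values, one sample per species.
        data = [rng.normal(loc=i + 1, size=30) for i in range(len(colors))]
        boxes = ax.boxplot(data, patch_artist=True)
        for patch, color in zip(boxes["boxes"], colors.values()):
            patch.set_facecolor(color)
        ax.set_xticks([1, 2, 3])
        ax.set_xticklabels(list(colors))
        ax.set_title(name)

    handles = [mpatches.Patch(color=c, label=s) for s, c in colors.items()]
    fig.suptitle("Distribution of floral traits in the species of iris")
    # Figure coordinates: (0, 0) is the lower left of the whole figure, (1, 1) the upper right.
    fig.legend(handles=handles, title="species", loc="upper right",
               bbox_to_anchor=(0.98, 0.98), ncol=3)
    # Reserve space at the top by hand, since the legend does not create any.
    fig.subplots_adjust(top=0.8)
    return fig
```

The values 0.98 and 0.8 are guesses that may need tuning for other figure sizes or legend contents.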

import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import patches as mpatches
import pandas as pd

dataset = sns.load_dataset("iris")

# Reindex the dataset by species so it can be pivoted for each species
reindexed_dataset = dataset.set_index(dataset.groupby('species').cumcount())
cols_to_pivot = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

# Collect one pivoted frame per measurement, then concatenate them
# (DataFrame.append was removed in pandas 2.0)
pivoted_frames = []
for var_name in cols_to_pivot:
    pivoted_dataset = reindexed_dataset.pivot(columns='species', values=var_name).rename_axis(None, axis=1)
    pivoted_dataset['measurement'] = var_name
    pivoted_frames.append(pivoted_dataset)
reshaped_dataset = pd.concat(pivoted_frames, ignore_index=True)

## Now, let's split the dataframe into groups by measurement.
grouped_dfs_02 = []
for group in reshaped_dataset.groupby('measurement'):
    grouped_dfs_02.append(group[1])

## make the box plot of several measured variables, compared between species
plt.figure(figsize=(20, 5), dpi=80)
plt.suptitle('Distribution of floral traits in the species of iris')

sp_name = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
setosa = mpatches.Patch(color='red')
versi = mpatches.Patch(color='green')
virgi = mpatches.Patch(color='blue')

my_pal = {"versicolor": "g", "setosa": "r", "virginica": "b"}
plt_index = 0

# for i, df in enumerate(grouped_dfs_02):
for group_name, df in reshaped_dataset.groupby('measurement'):
    axi = plt.subplot(1, len(grouped_dfs_02), plt_index + 1)
    sp_name = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
    df_melt = df.melt('measurement', var_name='species', value_name='values')

    sns.boxplot(data=df_melt, x='species', y='values', ax=axi, orient="v", palette=my_pal)
    plt.title(group_name)
    plt_index += 1

# Move the legend to an empty part of the plot
plt.legend(title='species', labels=sp_name,
           handles=[setosa, versi, virgi], bbox_to_anchor=(1, 1.23),
           fancybox=True, shadow=True, ncol=5, loc='upper right')
plt.tight_layout()
plt.show()

[image: resulting plot]

Upvotes: 1
