Reputation: 7255
I am making box plots using the "iris.csv" data. I am trying to split the data into multiple DataFrames by measurement (i.e. petal-length, petal-width, sepal-length, sepal-width) and then make a box plot in a for loop, adding one subplot per measurement.
Finally, I want to add a common legend for all the box plots at once, but I am not able to do it. I have tried several tutorials and methods from several Stack Overflow questions, but I am not able to fix it.
Here is my code:
import matplotlib.patches as mpatches
import pandas as pd
import seaborn as sns
from matplotlib import pyplot

iris_data = "iris.csv"
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
# Column names are supplied via names=, so read_csv treats the first line as data
# (assumes iris.csv has no header row -- confirm against the file).
dataset = pd.read_csv(iris_data, names=names)

# Reindex the dataset by position within each class so it can be pivoted per class.
reindexed_dataset = dataset.set_index(dataset.groupby('class').cumcount())

cols_to_pivot = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width']

# Build one wide frame (one column per class) per measurement and stack them.
# DataFrame.append was removed in pandas 2.0, so collect the pieces and concat once.
pivoted_frames = []
for var_name in cols_to_pivot:
    pivoted_dataset = (
        reindexed_dataset.pivot(columns='class', values=var_name)
        .rename_axis(None, axis=1)
    )
    pivoted_dataset['measurement'] = var_name
    pivoted_frames.append(pivoted_dataset)
reshaped_dataset = pd.concat(pivoted_frames, ignore_index=True)

# Split the dataframe into groups by measurement (one subplot per group).
grouped = reshaped_dataset.groupby('measurement')

# Make the box plot of several measured variables, compared between species.
pyplot.figure(figsize=(20, 5), dpi=80)
pyplot.suptitle('Distribution of floral traits in the species of iris')

sp_name = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
# Proxy artists for the legend; their colours match my_pal below.
setosa = mpatches.Patch(color='red')
versi = mpatches.Patch(color='green')
virgi = mpatches.Patch(color='blue')
my_pal = {"Iris-versicolor": "g", "Iris-setosa": "r", "Iris-virginica": "b"}

for plt_index, (group_name, df) in enumerate(grouped, start=1):
    axi = pyplot.subplot(1, grouped.ngroups, plt_index)
    df_melt = df.melt('measurement', var_name='species', value_name='values')
    sns.boxplot(data=df_melt, x='species', y='values', ax=axi, orient="v", palette=my_pal)
    pyplot.title(group_name)

# NOTE: pyplot.legend attaches to the current (last) subplot. bbox_to_anchor is
# given in that axes' coordinates, and loc defaults to 'best', so (19, 4) puts
# the legend far outside the figure. This is the problem the question asks about.
pyplot.legend(title='species', labels=sp_name,
              handles=[setosa, versi, virgi], bbox_to_anchor=(19, 4),
              fancybox=True, shadow=True, ncol=5)
pyplot.show()
How do I add a common legend to the main figure, outside the main frame, next to the main suptitle?
Upvotes: 1
Views: 9975
Reputation: 62523
See `matplotlib.pyplot.legend` and `matplotlib.axes.Axes.legend` for the `loc` and `bbox_to_anchor` parameters used to position the legend.
This does not work for `sns.histplot`, because `.get_legend_handles_labels()` does not work for it.
The iris dataset is available in `seaborn`, and it loads as a DataFrame.
Tested in `python 3.11.4`, `pandas 2.0.3`, `matplotlib 3.7.1`, `seaborn 0.12.2`.
In `seaborn 0.13.0`, `legend=True` may be required in the `sns.boxplot` call.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# load iris data
iris = sns.load_dataset("iris")
# iris.head():
#    sepal_length  sepal_width  petal_length  petal_width species
# 0           5.1          3.5           1.4          0.2  setosa
# 1           4.9          3.0           1.4          0.2  setosa
# 2           4.7          3.2           1.3          0.2  setosa
# 3           4.6          3.1           1.5          0.2  setosa
# 4           5.0          3.6           1.4          0.2  setosa

# create figure
fig, axes = plt.subplots(ncols=4, figsize=(20, 5), sharey=True)

# add subplots, one per measurement column
measurements = iris.columns.drop('species')
for ax, col in zip(axes, measurements):
    # hue='species' gives each box a labelled artist that the shared legend reuses
    sns.boxplot(x='species', y=col, data=iris, hue='species', dodge=False, ax=ax)
    # Remove the per-axes legend; get_legend() returns None when no legend was
    # drawn, so guard instead of calling .remove() on None.
    # (With seaborn 0.13.0, legend=True may be required on the boxplot call.)
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    ax.set_title(col)

# add legend, reusing the handles/labels from the last subplot
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper right', ncol=3, bbox_to_anchor=(0.8, 1), frameon=False)

# add subtitle
fig.suptitle('Distribution of floral traits in the species of iris')

plt.show()
An alternative is to reshape the DataFrame with `pandas.DataFrame.melt`, and then to plot with `sns.catplot` and `kind='box'`.
dfm = iris.melt(id_vars='species', var_name='parameter', value_name='measurement', ignore_index=True)
# dfm.head():
#   species     parameter  measurement
# 0  setosa  sepal_length          5.1
# 1  setosa  sepal_length          4.9
# 2  setosa  sepal_length          4.7
# 3  setosa  sepal_length          4.6
# 4  setosa  sepal_length          5.0

# One box-plot facet (column) per parameter, coloured by species.
g = sns.catplot(kind='box', data=dfm, x='species', y='measurement', hue='species', col='parameter', dodge=False)
# FacetGrid.figure is the current name for the grid's Figure (.fig is the older alias).
_ = g.figure.suptitle('Distribution of floral traits in the species of iris', y=1.1)
Putting `'parameter'` on the x-axis with `hue='species'` makes comparing each `'parameter'` across `'species'` easier.
g = sns.catplot(kind='box', data=dfm, x='parameter', y='measurement', hue='species', height=4, aspect=2)
# FacetGrid.figure is the current name for the grid's Figure (.fig is the older alias).
_ = g.figure.suptitle('Distribution of floral traits in the species of iris', y=1.1)
Upvotes: 9
Reputation: 80509
To position the legend, it is important to set the `loc` parameter, which determines the anchor point of the legend box. (The default `loc` is `'best'`, which means you don't know beforehand where the legend will end up.) The positions are measured from `0,0`, the lower left of the current ax, to `1,1`, the upper right of the current ax. This doesn't include the padding for titles etc., so the values can go a bit outside the `0, 1` range. The "current ax" is the last one that was activated.
Note that instead of `plt.legend` (which uses an axes), you could also use `plt.gcf().legend`, which uses the "figure". Then, the coordinates are `0,0` in the lower left corner of the complete plot (the "figure") and `1,1` in the upper right. A drawback is that no extra space is created for the legend, so you'd need to manually set a top padding (e.g. `plt.gcf().subplots_adjust(top=0.8)`). Another drawback is that you can't use `plt.tight_layout()` anymore, and that it would be harder to align the legend with the axes.
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import patches as mpatches
import pandas as pd

dataset = sns.load_dataset("iris")

# Reindex the dataset by position within each species so it can be pivoted per species.
reindexed_dataset = dataset.set_index(dataset.groupby('species').cumcount())

cols_to_pivot = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

# Build one wide frame (one column per species) per measurement and stack them.
# DataFrame.append was removed in pandas 2.0, so collect the pieces and concat once.
pivoted_frames = []
for var_name in cols_to_pivot:
    pivoted_dataset = (
        reindexed_dataset.pivot(columns='species', values=var_name)
        .rename_axis(None, axis=1)
    )
    pivoted_dataset['measurement'] = var_name
    pivoted_frames.append(pivoted_dataset)
reshaped_dataset = pd.concat(pivoted_frames, ignore_index=True)

## Now, let's split the dataframe into groups by measurement.
grouped = reshaped_dataset.groupby('measurement')

## make the box plot of several measured variables, compared between species
plt.figure(figsize=(20, 5), dpi=80)
plt.suptitle('Distribution of floral traits in the species of iris')

sp_name = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
# Proxy artists for the legend; their colours match my_pal below.
setosa = mpatches.Patch(color='red')
versi = mpatches.Patch(color='green')
virgi = mpatches.Patch(color='blue')
my_pal = {"versicolor": "g", "setosa": "r", "virginica": "b"}

for plt_index, (group_name, df) in enumerate(grouped, start=1):
    axi = plt.subplot(1, grouped.ngroups, plt_index)
    df_melt = df.melt('measurement', var_name='species', value_name='values')
    sns.boxplot(data=df_melt, x='species', y='values', ax=axi, orient="v", palette=my_pal)
    plt.title(group_name)

# Move the legend to an empty part of the plot. loc='upper right' anchors the
# legend's upper-right corner at bbox_to_anchor, given in the coordinates of the
# current (last) subplot; y=1.23 lifts it above that axes.
plt.legend(title='species', labels=sp_name,
           handles=[setosa, versi, virgi], bbox_to_anchor=(1, 1.23),
           fancybox=True, shadow=True, ncol=5, loc='upper right')
plt.tight_layout()
plt.show()
Upvotes: 1