Roughly speaking, I need to create one legend for several subplots, each of which has a different number of categories (and therefore of legend entries). Let me explain in more detail:
I have a figure with 20 subplots, one for each country within my spatial scope:
fig, ax = plt.subplots(nrows=4, ncols=5, sharex=True, sharey=False, figsize=(32, 18))
Within a loop, I group the data I need into an ordinary two-dimensional pandas DataFrame, stats, and plot it onto each of these 20 axes:
colors = stats.index.to_series().map(type_to_color()).tolist()
stats.T.plot.bar(ax=ax[i,j], stacked=True, legend=False, color=colors)
However, the stats DataFrame changes size from one iteration to the next, because not every category applies to every country (e.g. one country may have only two types, another more than 10).
For this reason, I predefined a specific color for each type.
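A minimal sketch of such a fixed type-to-color lookup, assuming a plain dictionary. The name TYPE_TO_COLOR and its colors are hypothetical stand-ins for the question's type_to_color(), whose body isn't shown. Because the lookup is keyed by the row label, a type keeps its color no matter which other types a country has:

```python
import pandas as pd

# Hypothetical fixed color per type; stands in for the question's type_to_color().
TYPE_TO_COLOR = {
    "Bioenergy": "tab:olive",
    "Hard Coal": "black",
    "Hydro": "tab:blue",
    "Natural Gas": "tab:orange",
    "Solar": "gold",
    "Wind": "tab:cyan",
}


def colors_for(stats: pd.DataFrame) -> list:
    """Look up one color per row label (type), in the row order of stats."""
    return stats.index.to_series().map(TYPE_TO_COLOR).tolist()


# Two countries with different subsets of types (2015 values from the question).
country_a = pd.DataFrame(
    {"2015": [29.229082, 0.926800, 25.799950]},
    index=["Hydro", "Natural Gas", "Wind"],
)
country_b = pd.DataFrame(
    {"2015": [22.834454, 5.100403]},
    index=["Hydro", "Solar"],
)

print(colors_for(country_a))  # ['tab:blue', 'tab:orange', 'tab:cyan']
print(colors_for(country_b))  # ['tab:blue', 'gold']
```

A type that is missing from the dictionary maps to NaN, so the dictionary should cover every type that can occur.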
So far, I am creating one legend for every subplot within the loop:
ax[i,j].legend(fontsize=9, loc='upper right')
This works, but it clutters the subplots unnecessarily. How can I draw one big legend above, below, or beside these plots, given that I have already defined the corresponding colors?
The approach given here with fig.legend(handles, labels, ...)
does not work for me, since I don't have the handles from the pandas plot.
Plotting the legend directly with
plt.legend(loc='lower center', bbox_to_anchor=(0, -0.3, 1, 1),
           bbox_transform=plt.gcf().transFigure)
shows only the entries for the very last subplot, which is not sufficient.
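Both symptoms can be checked in a small sketch, assuming two subplots instead of 20 and a hypothetical TYPE_TO_COLOR dictionary (the values are the 2015 to 2025 columns of the example tables further down). As far as I can tell, pandas still passes each column name as the label of its bar container even with legend=False, so ax.get_legend_handles_labels() returns handles for every axis; plt.legend() only looks at the current axes, which is the last subplot here. Merging the handles of all axes by label gives one figure-level legend:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical color map (stands in for type_to_color()) and two example frames.
TYPE_TO_COLOR = {"Bioenergy": "tab:olive", "Hard Coal": "black", "Hydro": "tab:blue",
                 "Natural Gas": "tab:orange", "Solar": "gold", "Wind": "tab:cyan"}
years = ["2015", "2020", "2025"]
frames = [
    pd.DataFrame([[29.229082, 28.964424, 28.528139],
                  [0.926800, 0.926800, 0.926800],
                  [25.799950, 25.797550, 0.776400]],
                 index=["Hydro", "Natural Gas", "Wind"], columns=years),
    pd.DataFrame([[0.033690, 0.033690, 0.030000],
                  [22.834454, 23.930642, 23.169014],
                  [5.100403, 5.100403, 5.100403]],
                 index=["Bioenergy", "Hydro", "Solar"], columns=years),
]

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
for ax, stats in zip(axes.flat, frames):
    colors = stats.index.to_series().map(TYPE_TO_COLOR).tolist()
    stats.T.plot.bar(ax=ax, stacked=True, legend=False, color=colors)

# The bar containers keep their labels, so each axis still has legend handles.
print([ax.get_legend_handles_labels()[1] for ax in axes.flat])
# [['Hydro', 'Natural Gas', 'Wind'], ['Bioenergy', 'Hydro', 'Solar']]

# plt.legend() would use the current axes, i.e. the last subplot.
print(plt.gca() is axes[-1])

# Merge the handles of all axes; setdefault keeps the first handle per label.
by_label = {}
for ax in axes.flat:
    for handle, label in zip(*ax.get_legend_handles_labels()):
        by_label.setdefault(label, handle)
fig.legend(list(by_label.values()), list(by_label.keys()),
           loc="lower center", ncol=len(by_label))
print(list(by_label))  # ['Hydro', 'Natural Gas', 'Wind', 'Bioenergy', 'Solar']
```

Keeping only the first handle per label is safe here because every type is drawn in the same color on every subplot.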
Any help is greatly appreciated! Thank you so much!
Edit
For example, the DataFrame stats could look like this for one country:
                  2015       2020       2025       2030       2035       2040
Hydro        29.229082  28.964424  28.528139  27.120194  25.932098  24.675778
Natural Gas   0.926800   0.926800   0.926800   0.926800   0.003600        NaN
Wind         25.799950  25.797550   0.776400   0.520800   0.234400        NaN
For another country, it might look like this:
                  2015       2020       2025       2030       2035
Bioenergy     0.033690   0.033690   0.030000        NaN        NaN
Hard Coal     5.307300   0.065100   0.021000        NaN        NaN
Hydro        22.834454  23.930642  23.169014  21.639914  19.623791
Natural Gas   8.378116   8.674121   8.013598   6.755498   5.255450
Solar         5.100403   5.100403   5.100403   5.100403   5.093403
Wind          8.983560   8.974740   8.967240   8.378300   0.195800
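To make the example reproducible, the two frames can be built directly. This sketch assumes the years are stored as string column labels (the actual dtype isn't shown). The union of their row labels is exactly the set of entries that a single shared legend has to cover:

```python
import numpy as np
import pandas as pd

years = ["2015", "2020", "2025", "2030", "2035", "2040"]

# First example country.
stats_a = pd.DataFrame(
    [[29.229082, 28.964424, 28.528139, 27.120194, 25.932098, 24.675778],
     [0.926800, 0.926800, 0.926800, 0.926800, 0.003600, np.nan],
     [25.799950, 25.797550, 0.776400, 0.520800, 0.234400, np.nan]],
    index=["Hydro", "Natural Gas", "Wind"],
    columns=years,
)

# Second example country: more types, one year fewer.
stats_b = pd.DataFrame(
    [[0.033690, 0.033690, 0.030000, np.nan, np.nan],
     [5.307300, 0.065100, 0.021000, np.nan, np.nan],
     [22.834454, 23.930642, 23.169014, 21.639914, 19.623791],
     [8.378116, 8.674121, 8.013598, 6.755498, 5.255450],
     [5.100403, 5.100403, 5.100403, 5.100403, 5.093403],
     [8.983560, 8.974740, 8.967240, 8.378300, 0.195800]],
    index=["Bioenergy", "Hard Coal", "Hydro", "Natural Gas", "Solar", "Wind"],
    columns=years[:5],
)

# Every type that appears in any country needs exactly one legend entry.
all_types = sorted(set(stats_a.index) | set(stats_b.index))
print(stats_a.shape, stats_b.shape)  # (3, 6) (6, 5)
print(all_types)  # ['Bioenergy', 'Hard Coal', 'Hydro', 'Natural Gas', 'Solar', 'Wind']
```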
Answer
Here's how I got one shared legend, sorted alphabetically, without mixing up the colors:
import collections

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=4, ncols=5, sharex=True, sharey=False, figsize=(32, 18))
labels_mpatches = collections.OrderedDict()

for a, b in enumerate(countries()):
    i, j = divmod(a, 5)  # assumed: row/column of this country's subplot in the 4 x 5 grid
    # do some data logic here (builds the DataFrame `stats` for country b)
    colors = stats.index.to_series().map(type_to_color()).tolist()
    stats.T.plot.bar(ax=ax[i, j], stacked=True, legend=False, color=colors)

    # Pass the legend information into the OrderedDict (one patch per new label)
    stats_handle, stats_labels = ax[i, j].get_legend_handles_labels()
    for u, v in enumerate(stats_labels):
        if v not in labels_mpatches:
            labels_mpatches[v] = mpatches.Patch(color=colors[u], label=v)

# After the loop, sort the entries alphabetically and lay out the legend.
labels_mpatches = collections.OrderedDict(sorted(labels_mpatches.items()))
fig.legend(labels_mpatches.values(), labels_mpatches.keys())
# !!! Please note: in previous versions this worked, but it does not anymore:
# fig.legend(handles=labels_mpatches.values(), labels=labels_mpatches.keys())
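Since every type already has a predefined color, a variant of this approach can skip get_legend_handles_labels() and the index bookkeeping with colors[u], and build the proxy patches straight from the color lookup for the types that were actually plotted. A minimal sketch, assuming a hypothetical TYPE_TO_COLOR dictionary in place of type_to_color() and two toy frames (2015 values from the question) instead of 20 countries:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical fixed color per type; stands in for type_to_color().
TYPE_TO_COLOR = {"Bioenergy": "tab:olive", "Hard Coal": "black", "Hydro": "tab:blue",
                 "Natural Gas": "tab:orange", "Solar": "gold", "Wind": "tab:cyan"}

frames = [
    pd.DataFrame({"2015": [29.229082, 25.799950]}, index=["Hydro", "Wind"]),
    pd.DataFrame({"2015": [5.307300, 22.834454]}, index=["Hard Coal", "Hydro"]),
]

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6, 3))
seen = set()
for ax, stats in zip(axes.flat, frames):
    colors = stats.index.to_series().map(TYPE_TO_COLOR).tolist()
    stats.T.plot.bar(ax=ax, stacked=True, legend=False, color=colors)
    seen.update(stats.index)  # remember which types were actually plotted

# One proxy patch per plotted type, alphabetically, colored from the same map.
handles = [mpatches.Patch(color=TYPE_TO_COLOR[t], label=t) for t in sorted(seen)]
fig.legend(handles=handles, loc="lower center", ncol=len(handles))
print([h.get_label() for h in handles])  # ['Hard Coal', 'Hydro', 'Wind']
```

Here the handles are passed as a plain list via handles=, and each patch's label supplies the legend text. Because the legend colors come from the same dictionary as the bar colors, the two stay consistent regardless of the order in which types appear in each country.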