stack4science

Single legend for subplots with changing categories (!) from a pandas DataFrame

Roughly speaking, I need to create one legend for several subplots, each of which has a different number of categories (i.e. legend entries). Let me explain in more detail:

I have a figure with 20 subplots, one for each country within my spatial scope:

import matplotlib.pyplot as plt
fig, ax = plt.subplots(nrows=4, ncols=5, sharex=True, sharey=False, figsize=(32, 18))
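The loop has to turn a running country index into a (row, column) position in the 4 x 5 grid. A minimal sketch, assuming the 20 countries are simply numbered 0 to 19 in plotting order (the range(20) loop is a hypothetical stand-in for the real country list): divmod gives the row and column, and ax.flat walks the same axes in the same row-major order.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=4, ncols=5, sharex=True, sharey=False, figsize=(32, 18))

positions = []
for idx in range(20):  # hypothetical: one index per country
    i, j = divmod(idx, 5)  # row = idx // 5, column = idx % 5
    positions.append((i, j))

# ax.flat visits the axes row by row, i.e. in the same order as divmod
same_order = all(a is ax[i, j] for a, (i, j) in zip(ax.flat, positions))
print(positions[:6], same_order)
```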

Inside a loop, I group the data I need into an ordinary two-dimensional pandas DataFrame called stats and plot it onto each of these 20 axes:

colors = stats.index.to_series().map(type_to_color()).tolist()
stats.T.plot.bar(ax=ax[i,j], stacked=True, legend=False, color=colors)
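type_to_color() is the asker's own helper and is not shown. A minimal sketch, assuming it returns a plain dict from category name to a matplotlib color (the colors below are hypothetical): Series.map turns the row labels of stats into a color list in row order. After stats.T those rows become the plotted columns, and pandas assigns the list to the columns in that order, so each category keeps its color regardless of which other categories are present.

```python
import pandas as pd

def type_to_color():
    # Hypothetical fixed mapping; the real one is defined by the asker.
    return {
        "Bioenergy": "tab:green",
        "Hard Coal": "black",
        "Hydro": "tab:blue",
        "Natural Gas": "tab:orange",
        "Solar": "gold",
        "Wind": "tab:cyan",
    }

# Stand-in for one country's stats (first two years of the first example table)
stats = pd.DataFrame(
    {2015: [29.229082, 0.926800, 25.799950], 2020: [28.964424, 0.926800, 25.797550]},
    index=["Hydro", "Natural Gas", "Wind"],
)

# One color per row of stats, i.e. per stacked bar segment after stats.T
colors = stats.index.to_series().map(type_to_color()).tolist()
print(colors)  # ['tab:blue', 'tab:orange', 'tab:cyan']
```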

However, the stats DataFrame changes size from one iteration to the next, since not every category applies to every country (e.g. one country may have only two types, another more than 10). For this reason I pre-defined a fixed color for each type. So far, I create one legend per subplot inside the loop:

ax[i,j].legend(fontsize=9, loc='upper right')

This works, but it clutters the subplots unnecessarily. How can I plot one big legend above, below or beside these plots, given that I have already defined the corresponding colors? The approach given here with fig.legend(handles, labels, ...) does not work, since the line handles are not available from the pandas plot. Plotting the legend directly with

plt.legend(loc='lower center', bbox_to_anchor=(0, -0.3, 1, 1),
           bbox_transform=plt.gcf().transFigure)

shows only the entries for the very last subplot, which is not sufficient.
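A likely explanation: plt.legend() acts on the current axes (plt.gca()), which here is the last subplot of the grid, so it only collects that subplot's bars; fig.legend() attaches a legend to the figure instead. A minimal sketch with two tiny hypothetical DataFrames on a 1 x 2 grid:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import pandas as pd

fig, ax = plt.subplots(nrows=1, ncols=2)
a = pd.DataFrame({2015: [1.0, 2.0]}, index=["Hydro", "Wind"])  # hypothetical data
b = pd.DataFrame({2015: [3.0]}, index=["Solar"])                 # hypothetical data
a.T.plot.bar(ax=ax[0], stacked=True, legend=False)
b.T.plot.bar(ax=ax[1], stacked=True, legend=False)

# plt.legend() goes to plt.gca(), i.e. the last axes of the grid
leg = plt.legend()
print([t.get_text() for t in leg.get_texts()])  # ['Solar']
print(ax[1].get_legend() is leg, fig.legends)   # True []
```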

Any help is greatly appreciated! Thank you so much!

Edit: For example, in one country the DataFrame stats could look like this:

              2015       2020       2025       2030       2035       2040                                                                          
Hydro        29.229082  28.964424  28.528139  27.120194  25.932098  24.675778   
Natural Gas   0.926800   0.926800   0.926800   0.926800   0.003600        NaN   
Wind         25.799950  25.797550   0.776400   0.520800   0.234400        NaN   

In another country, it might look like this:

              2015        2020        2025        2030        2035                                                                 
Bioenergy     0.033690    0.033690    0.030000         NaN         NaN   
Hard Coal     5.307300    0.065100    0.021000         NaN         NaN   
Hydro        22.834454   23.930642   23.169014   21.639914   19.623791   
Natural Gas   8.378116    8.674121    8.013598    6.755498    5.255450   
Solar         5.100403    5.100403    5.100403    5.100403    5.093403   
Wind          8.983560    8.974740    8.967240    8.378300    0.195800 
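The two example frames can be rebuilt from the values shown, which makes the core difficulty visible: each country has its own set of row labels, so a figure-wide legend has to be built from the union of all of them. A short sketch, assuming the year labels are integers (the printout does not say):

```python
import numpy as np
import pandas as pd

nan = np.nan
country_a = pd.DataFrame(
    [[29.229082, 28.964424, 28.528139, 27.120194, 25.932098, 24.675778],
     [0.926800, 0.926800, 0.926800, 0.926800, 0.003600, nan],
     [25.799950, 25.797550, 0.776400, 0.520800, 0.234400, nan]],
    index=["Hydro", "Natural Gas", "Wind"],
    columns=[2015, 2020, 2025, 2030, 2035, 2040],
)
country_b = pd.DataFrame(
    [[0.033690, 0.033690, 0.030000, nan, nan],
     [5.307300, 0.065100, 0.021000, nan, nan],
     [22.834454, 23.930642, 23.169014, 21.639914, 19.623791],
     [8.378116, 8.674121, 8.013598, 6.755498, 5.255450],
     [5.100403, 5.100403, 5.100403, 5.100403, 5.093403],
     [8.983560, 8.974740, 8.967240, 8.378300, 0.195800]],
    index=["Bioenergy", "Hard Coal", "Hydro", "Natural Gas", "Solar", "Wind"],
    columns=[2015, 2020, 2025, 2030, 2035],
)

# Every category that appears in at least one country, alphabetically
all_types = sorted(set(country_a.index) | set(country_b.index))
print(all_types)
```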

Answers (1)

stack4science

Here's how I got a single legend, in alphabetical order, without mixing up the colors:

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import collections

fig, ax = plt.subplots(nrows=4, ncols=5, sharex=True, sharey=False, figsize=(32, 18))
labels_mpatches = collections.OrderedDict()

for a, b in enumerate(countries()):  # a: running index, b: country
    i, j = divmod(a, 5)  # row and column, assuming countries fill the grid row by row
    # do some data logic here (builds the DataFrame stats for country b)
    colors = stats.index.to_series().map(type_to_color()).tolist()
    stats.T.plot.bar(ax=ax[i, j], stacked=True, legend=False, color=colors)
    # Pass the legend information into the OrderedDict (first color seen per label wins)
    _, stats_labels = ax[i, j].get_legend_handles_labels()
    for u, v in enumerate(stats_labels):
        if v not in labels_mpatches:
            labels_mpatches[v] = mpatches.Patch(color=colors[u], label=v)
# After the loop, do the legend layouting.
labels_mpatches = collections.OrderedDict(sorted(labels_mpatches.items()))
fig.legend(list(labels_mpatches.values()), list(labels_mpatches.keys()))
# !!! Please note: in previous versions the following worked, but it no longer does:
# fig.legend(handles=labels_mpatches.values(), labels=labels_mpatches.keys())
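To try this approach without the real data, here is a self-contained sketch. It assumes two hypothetical countries (cut-down versions of the example tables from the question: first two years and a subset of rows), a hypothetical dict TYPE_TO_COLOR standing in for type_to_color(), and a 1 x 2 grid instead of 4 x 5. It collects one proxy Patch per category across all subplots, then draws a single sorted figure legend below the axes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import pandas as pd

# Hypothetical stand-in for type_to_color()
TYPE_TO_COLOR = {"Bioenergy": "tab:green", "Hard Coal": "black", "Hydro": "tab:blue",
                 "Natural Gas": "tab:orange", "Solar": "gold", "Wind": "tab:cyan"}

# Hypothetical per-country stats (cut down from the question's examples)
frames = [
    pd.DataFrame({2015: [29.229082, 0.926800, 25.799950],
                  2020: [28.964424, 0.926800, 25.797550]},
                 index=["Hydro", "Natural Gas", "Wind"]),
    pd.DataFrame({2015: [0.033690, 5.307300, 22.834454, 5.100403],
                  2020: [0.033690, 0.065100, 23.930642, 5.100403]},
                 index=["Bioenergy", "Hard Coal", "Hydro", "Solar"]),
]

fig, ax = plt.subplots(nrows=1, ncols=2, sharex=True, figsize=(10, 4))
patches = {}  # label -> proxy artist; the first color seen for a label is kept

for axis, stats in zip(ax.flat, frames):
    colors = stats.index.to_series().map(TYPE_TO_COLOR).tolist()
    stats.T.plot.bar(ax=axis, stacked=True, legend=False, color=colors)
    # Labels come back in column order, matching the order of colors
    _, labels = axis.get_legend_handles_labels()
    for label, color in zip(labels, colors):
        patches.setdefault(label, mpatches.Patch(color=color, label=label))

# One figure-level legend, alphabetical, centered below the subplots
ordered = sorted(patches)
legend = fig.legend([patches[k] for k in ordered], ordered,
                    loc="lower center", ncol=len(ordered))
fig.subplots_adjust(bottom=0.25)  # leave room for the legend
print([t.get_text() for t in legend.get_texts()])
```

Since the colors are fixed per category anyway, the proxy patches could also be built straight from the color mapping; going through get_legend_handles_labels() just ensures that only categories which actually appear in some subplot end up in the legend.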
