Giuseppe Cardellini

Reputation: 135

Create a single legend for subplots with different columns in pandas

I have to draw two subplots from two DataFrames that have different columns. I want one common legend that contains all the columns of both, similar to the example below.

import pandas as pd
import matplotlib.pyplot as plt

d1 = pd.DataFrame({  # no 'one' column
    'two': [-1., -2., -3., -4.],
    'three': [4., 3., 2., 1.],
    'four': [4., 3., 4., 3.]})
tot_1=d1.sum(axis=1)

d2 = pd.DataFrame({'one' : [1., 2., 3., 4.],
    'two' : [4., 3., 3., 1.],
    'three' : [-1., -1., -3., -4.],
    'four' : [4., 3., 2., 1.]})

tot_2=d2.sum(axis=1)


fig, ax = plt.subplots(nrows=2, ncols=1)

#plot 1
d1.plot.area(stacked=True,legend=False,ax=ax[0])
tot_1.plot(linestyle='-', color='black',legend=False,ax=ax[0])

###SECOND GRAPH####

ax3 = ax[1].twiny()

#plot 2
d2.plot.area(stacked=True,legend=False,ax=ax[1],sharex=ax[0])
tot_2.plot(linestyle='-',color='black',legend=False,ax=ax[1])

plt.show()


The problem is that the columns are not exactly the same in the two DataFrames/plots (some appear in both, some in only one). I need to make sure that the legend contains all the columns and that the colors in the two plots match the legend.

Even better (but not necessary) would be if I could choose the color for each column, for instance by passing a dictionary of col_name: color pairs.

Upvotes: 0

Views: 1424

Answers (1)

ImportanceOfBeingErnest

Reputation: 339220

You can indeed use a dictionary of column-name: color pairs to color the area patches and afterwards create the legend from the same dictionary.
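The mapping itself can also be derived from the frames instead of typing the column list by hand. A minimal sketch, where the helper `color_map_for` and its `palette` argument are hypothetical names: it collects every column of every frame in first-seen order and assigns Matplotlib's default cycle colors "C0", "C1", and so on (wrapping after ten), or the colors from a list you pass in.

```python
import pandas as pd


def color_map_for(*frames, palette=None):
    """Hypothetical helper: map every column found in any frame to one color.

    Columns are taken in first-seen order, so the mapping (and a legend built
    from it) is the same on every run.
    """
    # Union of the column names; dict keys keep insertion order and drop duplicates.
    names = list(dict.fromkeys(col for df in frames for col in df.columns))
    if palette is None:
        # "C0" ... "C9" refer to Matplotlib's default property-cycle colors.
        palette = ["C" + str(i % 10) for i in range(len(names))]
    # A user-supplied palette needs at least one color per column; zip stops at the shorter list.
    return dict(zip(names, palette))


d1 = pd.DataFrame({'two': [-1., -2.], 'three': [4., 3.], 'four': [4., 3.]})
d2 = pd.DataFrame({'one': [1., 2.], 'two': [4., 3.], 'three': [-1., -1.], 'four': [4., 3.]})

colors = color_map_for(d1, d2)
print(colors)  # {'two': 'C0', 'three': 'C1', 'four': 'C2', 'one': 'C3'}
```

The resulting dictionary can be used wherever a `colors[...]` lookup is needed, e.g. `color=[colors[c] for c in d1.columns]`.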

import pandas as pd
import matplotlib.pyplot as plt

d1 = pd.DataFrame({  # no 'one' column
    'two': [-1., -2., -3., -4.],
    'three': [4., 3., 2., 1.],
    'four': [4., 3., 4., 3.]})
tot_1=d1.sum(axis=1)

d2 = pd.DataFrame({'one' : [1., 2., 3., 4.],
    'two' : [4., 3., 3., 1.],
    'three' : [-1., -1., -3., -4.],
    'four' : [4., 3., 2., 1.]})
tot_2=d2.sum(axis=1)

# Every column that occurs in either DataFrame, mapped to one color from the default cycle
columns = ["one", "two", "three", "four"]
colors = dict(zip(columns, ["C" + str(i) for i in range(len(columns))]))

fig, ax = plt.subplots(nrows=2, ncols=1)

#plot 1
d1.plot.area(stacked=True,legend=False,ax=ax[0], lw=0,
             color=[colors[i] for i in d1.columns])
tot_1.plot(linestyle='-', color='black',legend=False,ax=ax[0])

###SECOND GRAPH####

ax3 = ax[1].twiny()

#plot 2
d2.plot.area(stacked=True,legend=False,ax=ax[1],sharex=ax[0], lw=0,
             color=[colors[i] for i in d2.columns])
tot_2.plot(linestyle='-',color='black',legend=False,ax=ax[1])

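# Legend entries for every column present in either frame, in the order of `columns`
# (building them from a set would give a run-dependent order)
labels = [c for c in columns if c in d1.columns or c in d2.columns]

# One proxy rectangle per label, colored like the corresponding area
handles = [plt.Rectangle((0,0),1,1, color=colors[l]) for l in labels]
ax3.legend(handles=handles, labels=labels)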
plt.show()
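A variant of the same idea, sketched under a few assumptions: the hand-picked `tab:` colors stand in for the col_name: color dictionary asked about in the question, the subplots share the x axis through `plt.subplots(sharex=True)` instead of passing `sharex` to pandas, and the legend is attached to the whole figure with `fig.legend` using `matplotlib.patches.Patch` proxies, so no twin axes are needed. The `Agg` backend line is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to get an interactive window
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import pandas as pd

d1 = pd.DataFrame({'two': [-1., -2., -3., -4.],
                   'three': [4., 3., 2., 1.],
                   'four': [4., 3., 4., 3.]})
d2 = pd.DataFrame({'one': [1., 2., 3., 4.],
                   'two': [4., 3., 3., 1.],
                   'three': [-1., -1., -3., -4.],
                   'four': [4., 3., 2., 1.]})

# Assumed hand-picked colors, one per column name.
colors = {'one': 'tab:blue', 'two': 'tab:orange', 'three': 'tab:green', 'four': 'tab:red'}

fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True)
for ax, df in zip(axes, (d1, d2)):
    # The same column gets the same color in both panels; lw=0 hides the outlines.
    df.plot.area(stacked=True, legend=False, ax=ax, lw=0,
                 color=[colors[c] for c in df.columns])
    # Row totals as a black line on top of the stacked areas.
    df.sum(axis=1).plot(color='black', ax=ax)

# One legend entry per column present in either frame, in the dictionary's order.
present = set(d1.columns) | set(d2.columns)
handles = [Patch(color=colors[c], label=c) for c in colors if c in present]
legend = fig.legend(handles=handles, loc='upper right')

fig.canvas.draw()  # render off-screen; use plt.show() in an interactive session
```

Because the legend entries come from the dictionary rather than from either axes, a column that appears in only one frame still gets exactly one entry, with the same color it has in the plot.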


Upvotes: 2
