Reputation: 135
I have to draw two subplots from two DataFrames that have different columns. I want a common legend that contains all the columns of both, similar to the example below.
import pandas as pd
import matplotlib.pyplot as plt
d1 = pd.DataFrame({  # without 'one'
    'two': [-1., -2., -3., -4.],
    'three': [4., 3., 2., 1.],
    'four': [4., 3., 4., 3.]})
tot_1=d1.sum(axis=1)
d2 = pd.DataFrame({'one' : [1., 2., 3., 4.],
'two' : [4., 3., 3., 1.],
'three' : [-1., -1., -3., -4.],
'four' : [4., 3., 2., 1.]})
tot_2=d2.sum(axis=1)
fig, ax = plt.subplots(nrows=2, ncols=1)
#plot 1
d1.plot.area(stacked=True,legend=False,ax=ax[0])
tot_1.plot(linestyle='-', color='black',legend=False,ax=ax[0])
###SECOND GRAPH####
ax3 = ax[1].twiny()
#plot 2
d2.plot.area(stacked=True,legend=False,ax=ax[1],sharex=ax[0])
tot_2.plot(linestyle='-',color='black',legend=False,ax=ax[1])
plt.show()
The problem is that the columns are not exactly the same in the two DataFrames/plots (some appear in both, some do not). The legend should list all the columns, and the colors in the two plots and in the legend should match.
Even better (but not necessary) would be if I could choose the color for each column, for instance by passing a dictionary of col_name:color pairs.
Upvotes: 0
Views: 1424
Reputation: 339220
You can indeed use a dictionary of column-name:color pairs to color the patches, and afterwards create the legend from that same dictionary.
import pandas as pd
import matplotlib.pyplot as plt
d1 = pd.DataFrame({  # without 'one'
    'two': [-1., -2., -3., -4.],
    'three': [4., 3., 2., 1.],
    'four': [4., 3., 4., 3.]})
tot_1=d1.sum(axis=1)
d2 = pd.DataFrame({'one' : [1., 2., 3., 4.],
'two' : [4., 3., 3., 1.],
'three' : [-1., -1., -3., -4.],
'four' : [4., 3., 2., 1.]})
tot_2=d2.sum(axis=1)
columns = ["one", "two", "three", "four"]
colors = dict(zip(columns, ["C"+str(i) for i in range(len(columns)) ]))
fig, ax = plt.subplots(nrows=2, ncols=1)
#plot 1
d1.plot.area(stacked=True,legend=False,ax=ax[0], lw=0,
color=[colors[i] for i in d1.columns])
tot_1.plot(linestyle='-', color='black',legend=False,ax=ax[0])
###SECOND GRAPH####
ax3 = ax[1].twiny()
#plot 2
d2.plot.area(stacked=True,legend=False,ax=ax[1],sharex=ax[0], lw=0,
color=[colors[i] for i in d2.columns])
tot_2.plot(linestyle='-',color='black',legend=False,ax=ax[1])
# keep the order of `columns` so the legend order is the same on every run
labels = [c for c in columns if c in d1.columns or c in d2.columns]
handles = [plt.Rectangle((0,0),1,1, color=colors[l]) for l in labels]
ax3.legend(handles=handles, labels=labels)
plt.show()
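If the full list of column names is not known in advance, the color dictionary can also be derived from the DataFrames themselves. The following is only a minimal sketch of that idea: the helper names `ordered_union` and `shared_legend_handles` are made up for illustration, the two lists stand in for `d1.columns` and `d2.columns`, and the legend is attached to the figure rather than to a twin axis. The `Agg` backend is selected only so the snippet runs without a display; drop that line for interactive use.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove for interactive use
import matplotlib.pyplot as plt
from matplotlib.patches import Patch


def ordered_union(*column_lists):
    """Return all column names in first-seen order, without duplicates."""
    seen = {}
    for cols in column_lists:
        for name in cols:
            seen.setdefault(name, None)
    return list(seen)


def shared_legend_handles(color_map, labels):
    """Build one colored legend patch per label (hypothetical helper)."""
    return [Patch(facecolor=color_map[name], label=name) for name in labels]


cols_1 = ["two", "three", "four"]         # stands in for d1.columns
cols_2 = ["one", "two", "three", "four"]  # stands in for d2.columns

labels = ordered_union(cols_1, cols_2)
# Assign the default color cycle entries C0, C1, ... in label order.
color_map = {name: f"C{i}" for i, name in enumerate(labels)}

fig, ax = plt.subplots(nrows=2, ncols=1)
handles = shared_legend_handles(color_map, labels)
# A figure-level legend is shared by both subplots.
legend = fig.legend(handles=handles, loc="upper right")
```

The same `color_map` can then be passed to each plot as `color=[color_map[c] for c in df.columns]`, as in the code above, so that the areas and the legend patches agree.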
Upvotes: 2