John Conor

Reputation: 884

How to change the legend location for pandas subplots

I'm trying to create a pandas plot with a large number of subplots (58 in this case). The data is in wide form, in a format similar to this:

df = 

Date It1 It2 It3... Itn 
0    x    x   x      n
1    x    x   x      n
2    x    x   x      n
3    x    x   x      n

I have been able to create the plot without problems using pandas plot:

import matplotlib.pyplot as plt

rows = df.shape[1]//2
df.plot(legend=True, subplots=True, layout=(rows, 5), grid=True, title="Labs", sharex=True, sharey=False, figsize=(12, 32))
plt.show()
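A side note on the layout above: with 58 columns, df.shape[1]//2 asks for 29 rows, i.e. 145 grid cells for 58 plots. pandas hides the unused cells, but they still take up space in the 12 x 32 figure, which shrinks every visible subplot. A minimal sketch of sizing the grid with ceiling division instead, assuming 5 plots per row; grid_layout is a hypothetical helper name:

```python
import math


def grid_layout(n_series: int, ncols: int = 5) -> tuple[int, int]:
    """Hypothetical helper: (rows, cols) with just enough cells for n_series subplots."""
    # ceiling division: the last row may be only partially filled
    rows = math.ceil(n_series / ncols)
    return rows, ncols


# 58 series need 12 rows of 5 (60 cells) rather than 29 rows (145 cells)
print(grid_layout(58))  # (12, 5)
```

The result can be passed straight to pandas, e.g. df.plot(subplots=True, layout=grid_layout(df.shape[1]), ...).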

But I am having trouble positioning the legends so that all of the graphs are legible. This is an example of how they currently look:

[Image: example of the current plot]

I've tried both solutions in this other Stack Overflow post: Set the legend location of a pandas plot

...but neither actually works. I also tried using tight_layout() per this answer, but the result is equally illegible: Plot pandas dataframe with subplots (subplots=True): Place legend and use tight layout

Can anyone offer any guidance on how to place the legends of a figure with so many subplots while keeping it readable?

Upvotes: 2

Views: 2569

Answers (1)

Trenton McKinney

Reputation: 62383

  • pandas.DataFrame.plot with subplots=True returns a numpy.ndarray of matplotlib.axes.Axes
  • The easiest way to access each subplot's axes is to flatten the array and iterate through it.
  • Use the answers to How to put the legend out of the plot to place the legend in an appropriate location.
  • Other axes-level modifications can be made inside the loop using the standard matplotlib object-oriented methods (i.e., those beginning with ax.).
  • figsize must be adjusted depending on the number of rows and columns.
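The steps in the list above can be sketched as follows, this time placing each legend just outside the right edge of its subplot (one of the approaches from the legend-placement question) and scaling figsize with the grid. The random-walk data, the 3 x 4 grid, and the wspace value are illustrative assumptions; the spacing will need tuning for longer labels.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import numpy as np
import pandas as pd

# hypothetical wide-form data: 12 random-walk series, one per column
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(50, 12)).cumsum(axis=0),
                  columns=[f"It{i}" for i in range(1, 13)])

nrows, ncols = 3, 4
# scale figsize with the grid so each subplot keeps roughly the same size
axes = df.plot(subplots=True, layout=(nrows, ncols), sharex=True,
               figsize=(4 * ncols, 3 * nrows))

# with layout=..., the result is a 2-D numpy.ndarray of Axes, shape (nrows, ncols)
for ax in axes.ravel():
    # anchor the legend's left edge at the right edge of its own axes
    ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5), frameon=False)

fig = axes[0, 0].get_figure()
# widen the gaps between columns so each legend has room (tune for label length)
fig.subplots_adjust(wspace=0.6)
```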
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# sinusoidal sample data
sample_length = range(1, 15+1)
rads = np.arange(0, 2*np.pi, 0.01)
data = np.array([np.sin(t*rads) for t in sample_length])
df = pd.DataFrame(data.T, index=pd.Series(rads.tolist(), name='radians'), columns=[f'freq: {i}x' for i in sample_length])

# plot the data with subplots and assign the returned array
axes = df.plot(subplots=True, layout=(3, 5), figsize=(25, 15))

# flatten the array
axes = axes.flat  # .ravel() and .flatten() also work

# extract the figure object to use figure-level methods
fig = axes[0].get_figure()

# iterate through each axes to use axes-level methods
for ax in axes:
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.10), frameon=False)

fig.suptitle('Sinusoids of Different Frequency', fontsize=22, y=0.95)
plt.show()

[Image: resulting figure with a legend above each subplot]
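If relocated legends are still too crowded with 58 subplots, one alternative (not part of the answer above) is to skip legends entirely and label each subplot with its column name: DataFrame.plot accepts a list for title when subplots=True and prints each item above the corresponding subplot. A minimal sketch, assuming hypothetical 58-column data shaped like the question's and 5 plots per row:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import math
import numpy as np
import pandas as pd

# hypothetical wide-form data with 58 columns, like the question's
rng = np.random.default_rng(1)
df = pd.DataFrame(rng.normal(size=(30, 58)).cumsum(axis=0),
                  columns=[f"It{i}" for i in range(1, 59)])

ncols = 5
nrows = math.ceil(df.shape[1] / ncols)  # 12 rows for 58 series

# legend=False suppresses the per-subplot legends;
# a list title puts each column name above its own subplot
axes = df.plot(subplots=True, layout=(nrows, ncols), legend=False,
               title=list(df.columns), sharex=True,
               figsize=(4 * ncols, 2.5 * nrows))

fig = axes[0, 0].get_figure()
fig.suptitle("Labs")  # figure-level title, as in the question
```

The two unused cells in the last row are hidden by pandas, so only 58 titled subplots are drawn.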

Upvotes: 4
