Azlin Sohail

Reputation: 11

How to iteratively remove X axis labels from multiple subplots

Using seaborn, I have created 21 subplots. My code is as follows:

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import MaxNLocator

fig, axes = plt.subplots(7, 3, figsize=(25, 25))

fig.suptitle('Workforce Statistics')

sns.lineplot(ax=axes[0, 0], data=dfStaff, x='Month', y='Central functions')
sns.lineplot(ax=axes[0, 1], data=dfStaff, x='Month', y='Support to ST&T staff')
sns.lineplot(ax=axes[0, 2], data=dfStaff, x='Month', y='Consultant')
sns.lineplot(ax=axes[1, 0], data=dfStaff, x='Month', y='Specialty Registrar')
sns.lineplot(ax=axes[1, 1], data=dfStaff, x='Month', y='Midwives')
sns.lineplot(ax=axes[1, 2], data=dfStaff, x='Month', y='Managers')
sns.lineplot(ax=axes[2, 0], data=dfStaff, x='Month', y='Ambulance staff')
sns.lineplot(ax=axes[2, 1], data=dfStaff, x='Month', y='Support to ambulance staff')
sns.lineplot(ax=axes[2, 2], data=dfStaff, x='Month', y='Senior managers')
sns.lineplot(ax=axes[3, 0], data=dfStaff, x='Month', y='Core Training')
sns.lineplot(ax=axes[3, 1], data=dfStaff, x='Month', y='Specialty Doctor')
sns.lineplot(ax=axes[3, 2], data=dfStaff, x='Month', y='Foundation Doctor Year 1')
sns.lineplot(ax=axes[4, 0], data=dfStaff, x='Month', y='Foundation Doctor Year 2')
sns.lineplot(ax=axes[4, 1], data=dfStaff, x='Month', y='Other staff or those with unknown classification')
sns.lineplot(ax=axes[4, 2], data=dfStaff, x='Month', y='Associate Specialist')
sns.lineplot(ax=axes[5, 0], data=dfStaff, x='Month', y='Hospital Practitioner / Clinical Assistant')
sns.lineplot(ax=axes[5, 1], data=dfStaff, x='Month', y='Other and Local HCHS Doctor Grades')
sns.lineplot(ax=axes[5, 2], data=dfStaff, x='Month', y='Staff Grade')
sns.lineplot(ax=axes[6, 0], data=dfStaff, x='Month', y='Nurses & health visitors')
sns.lineplot(ax=axes[6, 1], data=dfStaff, x='Month', y='Support to doctors, nurses & midwives')
sns.lineplot(ax=axes[6, 2], data=dfStaff, x='Month', y='HCHS doctors')

axes[0,0].xaxis.set_major_locator(MaxNLocator(6)) 
axes[0,1].xaxis.set_major_locator(MaxNLocator(6)) 
axes[0,2].xaxis.set_major_locator(MaxNLocator(6)) 
axes[1,0].xaxis.set_major_locator(MaxNLocator(6)) 
axes[1,1].xaxis.set_major_locator(MaxNLocator(6)) 
axes[1,2].xaxis.set_major_locator(MaxNLocator(6)) 
axes[2,0].xaxis.set_major_locator(MaxNLocator(6)) 
axes[2,1].xaxis.set_major_locator(MaxNLocator(6)) 
axes[2,2].xaxis.set_major_locator(MaxNLocator(6)) 
axes[3,0].xaxis.set_major_locator(MaxNLocator(6)) 
axes[3,1].xaxis.set_major_locator(MaxNLocator(6)) 
axes[3,2].xaxis.set_major_locator(MaxNLocator(6)) 
axes[4,0].xaxis.set_major_locator(MaxNLocator(6)) 
axes[4,1].xaxis.set_major_locator(MaxNLocator(6)) 
axes[4,2].xaxis.set_major_locator(MaxNLocator(6)) 
axes[5,0].xaxis.set_major_locator(MaxNLocator(6)) 
axes[5,1].xaxis.set_major_locator(MaxNLocator(6)) 
axes[5,2].xaxis.set_major_locator(MaxNLocator(6)) 
axes[6,0].xaxis.set_major_locator(MaxNLocator(6)) 
axes[6,1].xaxis.set_major_locator(MaxNLocator(6)) 
axes[6,2].xaxis.set_major_locator(MaxNLocator(6)) 

For the second part, I am trying to write a for loop that iterates over all the axes and calls set_major_locator on each, but I keep running into errors.

Upvotes: 1

Views: 167

Answers (1)

Trenton McKinney

Reputation: 62503

  1. Flatten axes
  2. Create a list, y_cols, of all the columns to be used for y
  3. Iterate through axes and y_cols
  • The following code uses df from the Working Example below
  • Tested using python 3.8.11, pandas 1.3.1, matplotlib 3.4.2, and seaborn 0.11.1.
import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(3, 2, figsize=(12, 6))

# flatten axes into a 1D array, which is easier to iterate through
axes = axes.flatten()

# specify the y columns in a list
y_cols = df.columns[1:]

fig.suptitle('Workforce Statistics')

for ax, y in zip(axes, y_cols):
    
    sns.lineplot(ax=ax, data=df, x='Month', y=y)
    ax.set(title=y, ylabel='Something', xlabel='Date')

    ax.xaxis.set_major_locator(plt.MaxNLocator(5)) 
    
fig.tight_layout()
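The question doesn't include the error message, so this is only a guess: a common mistake is looping directly over the 2-D array returned by plt.subplots(7, 3). That yields whole rows (NumPy arrays of 3 Axes) rather than individual Axes, so ax.xaxis raises an AttributeError. A minimal sketch of the difference; set_locators is a hypothetical helper, and axes.flat is used as an iterator-based alternative to flatten():

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

fig, axes = plt.subplots(7, 3, figsize=(25, 25))

# Iterating the 2-D array directly yields one row (a NumPy array of 3 Axes) at a time
first_item = next(iter(axes))
print(type(first_item).__name__)  # ndarray, which has no .xaxis attribute


def set_locators(axes_grid, nbins=6):
    """Hypothetical helper: apply MaxNLocator to every Axes in a grid of any shape."""
    count = 0
    for ax in axes_grid.flat:  # .flat walks all Axes in row-major order
        ax.xaxis.set_major_locator(MaxNLocator(nbins))
        count += 1
    return count


n = set_locators(axes)
print(n)  # 21
```

axes.flat avoids making a copy, while axes.flatten() returns a new 1-D array; either works for a loop like this.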


  • Seaborn is just a high-level API for matplotlib
  • Alternatively, use pandas.DataFrame.plot, since you're plotting columns of a DataFrame.
    • If y= is not specified, all columns other than 'Month' will be plotted. Otherwise, create the column list and pass it with y=y_cols.
    • It also uses matplotlib as the backend.
import matplotlib.pyplot as plt

axes = dfStaff.plot(x='Month', subplots=True, layout=(7, 3), figsize=(25, 25))
axes = axes.flatten()
for ax in axes:
    ax.xaxis.set_major_locator(plt.MaxNLocator(6)) 
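To plot only some of the columns with pandas, pass a list to y=. A minimal sketch with made-up data standing in for dfStaff; the columns a, b, c and the (1, 2) layout are illustrative assumptions:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical sample data standing in for dfStaff
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "Month": pd.date_range("2017-01-01", periods=36, freq="MS"),
    "a": rng.integers(0, 10, 36),
    "b": rng.integers(15, 25, 36),
    "c": rng.integers(30, 40, 36),
})

y_cols = ["a", "c"]  # plot only a subset of the columns
axes = df.plot(x="Month", y=y_cols, subplots=True, layout=(1, 2), figsize=(10, 3))
print(axes.shape)  # (1, 2)

# Same loop as above: one locator per subplot
for ax in axes.flatten():
    ax.xaxis.set_major_locator(plt.MaxNLocator(6))
```

Because layout= is given, pandas returns a 2-D array of Axes shaped like the layout, which is why flatten() is still needed before the loop.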

Working example

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# sample data
np.random.seed(365)
rows = 365*3
data = {'Month': pd.date_range('2017-01-10', freq='D', periods=rows),
        'a': np.random.randint(0, 10, size=(rows)),
        'b': np.random.randint(15, 25, size=(rows)),
        'c': np.random.randint(30, 40, size=(rows)),
        'd': np.random.randint(450, 550, size=(rows)),
        'e': np.random.randint(6000, 7000, size=(rows)),
        'f': np.random.randint(100, 201, size=(rows))}
df = pd.DataFrame(data)
df.head()

# display(df.head())
       Month  a   b   c    d     e    f
0 2017-01-10  2  17  36  480  6539  101
1 2017-01-11  4  18  30  482  6955  152
2 2017-01-12  1  16  30  504  6472  105
3 2017-01-13  5  17  32  519  6269  113
4 2017-01-14  2  17  37  534  6654  160

# plot
axes = df.plot(x='Month', subplots=True, layout=(2, 3), figsize=(15, 6),
               title='Workforce Statistics - with MaxNLocator', xlabel='Date')
axes = axes.flatten()
for ax in axes:
    ax.xaxis.set_major_locator(plt.MaxNLocator(6)) 
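The question title asks how to remove the x-axis labels, which the code above doesn't cover directly. One option (an assumption, not part of the answer) is to clear the x label and hide the x tick labels on every row except the bottom one after plotting. A minimal sketch with random data standing in for dfStaff; col0 to col20 are hypothetical column names, and ax.plot stands in for sns.lineplot to keep the example dependency-light:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import MaxNLocator

# Hypothetical stand-in for dfStaff: a Month column plus 21 numeric columns
rng = np.random.default_rng(1)
months = pd.date_range("2017-01-01", periods=24, freq="MS")
dfStaff = pd.DataFrame({f"col{i}": rng.integers(0, 100, len(months)) for i in range(21)})
dfStaff.insert(0, "Month", months)

fig, axes = plt.subplots(7, 3, figsize=(25, 25))
fig.suptitle("Workforce Statistics")

for ax, y in zip(axes.flat, dfStaff.columns[1:]):
    # with seaborn: sns.lineplot(ax=ax, data=dfStaff, x='Month', y=y)
    ax.plot(dfStaff["Month"], dfStaff[y])
    ax.set(title=y, xlabel="Month")
    ax.xaxis.set_major_locator(MaxNLocator(6))

# Every row except the last: drop the x-axis label and the x tick labels
for ax in axes[:-1, :].flat:
    ax.set_xlabel("")
    ax.tick_params(axis="x", labelbottom=False)
```

Indexing axes[:-1, :] relies on axes still being the 2-D array from plt.subplots, so do this before (or instead of) reassigning axes to a flattened copy.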


Upvotes: 2
