Florian

Reputation: 311

Share axis and remove unused in matplotlib subplots

I want to plot a series of seaborn heatmaps in a grid. I know the number of subplots (which can be odd or even). The heatmaps show the mean "occupation ratio" by "day of week" (y axis) and "hour of day" (x axis), i.e. they all share the same x/y domains.

Here's my current code:

import matplotlib.pyplot as plt
import seaborn as sns

# df holds the parking data (columns name, openLots, occupationRatio, DoW, Hour)
df2 = df[['name','openLots','occupationRatio','DoW','Hour']]
fig, axs = plt.subplots(figsize=(24,24), nrows=7, ncols=6)
axs = axs.flatten()
locations = df2['name'].sort_values().unique()


def occupation_heatmap(name, ax):
    dfn = df2[df2['name'] == name]
    # Mean occupation ratio per (day, hour); the hours become the columns
    dfn = dfn.groupby(['DoW', 'Hour'])['occupationRatio'].mean().unstack()
    dfn = dfn.reindex(['Mon', 'Tue', 'Wed','Thu','Fri','Sat','Sun'])
    sns.heatmap(data=dfn, cmap="coolwarm", vmin=0, vmax=1.0, ax=ax)
    ax.set_title(name)


for i, n in enumerate(locations):
    occupation_heatmap(n, axs[i])

plt.tight_layout()
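The aggregation step inside `occupation_heatmap` (filter one location, average `occupationRatio` per day/hour, pivot the hours into columns, force a Monday-to-Sunday row order) can be checked in isolation. Below is a minimal sketch on a tiny hand-made DataFrame; the `day_hour_table` helper and the toy rows are hypothetical. Selecting `occupationRatio` before calling `.mean()` avoids averaging the string `name` column, which pandas 2.0 and later reject with a TypeError.

```python
import pandas as pd

DAYS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']

def day_hour_table(df, name):
    """Hypothetical helper: mean occupationRatio for one location,
    rows = day of week, columns = hour of day."""
    dfn = df[df['name'] == name]
    # Select the numeric column first so the string 'name' column is not averaged
    table = dfn.groupby(['DoW', 'Hour'])['occupationRatio'].mean().unstack()
    # Force Mon..Sun order; days without any data become all-NaN rows
    return table.reindex(DAYS)

# Toy data, made up for illustration only
toy = pd.DataFrame({
    'name': ['A', 'A', 'A', 'B'],
    'DoW': ['Tue', 'Mon', 'Mon', 'Mon'],
    'Hour': [8, 8, 8, 9],
    'occupationRatio': [0.5, 0.25, 0.75, 0.9],
})

t = day_hour_table(toy, 'A')
print(t)
```

Because of the `reindex`, every table has the same seven rows in the same order, which is what lets all heatmaps in the grid share one y axis.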

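It looks almost like what I want (screenshot of the last few rows not included); note the axis labels, the colorbar legends and the "empty" subplots. However, what I want is:

  • the y axis labels (DoW) only once per row (leftmost plot)
  • the colormap legend only on the rightmost plot in each row (or no legend at all)
  • no "empty plots" in the last row when the total number is odd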

Many thanks for any hints

Upvotes: 2

Views: 4036

Answers (2)

ImportanceOfBeingErnest

Reputation: 339550

  • Have the y axis labels (DoW) only once per row (leftmost plot)
    This can be done by passing sharey=True to plt.subplots, which hides the y tick labels of all but the leftmost column.
  • Have the colormap legend only on the rightmost plot in each row (or leave it out completely, the colors are pretty self-explanatory)
    Pass cbar=False to seaborn.heatmap to suppress the colorbar. Whether to draw it can be passed to the plotting function as an argument, depending on the subplot's position in the grid.
  • Remove the "empty plots" in the last row (caused by an odd total number)
    After the loop that creates the plots, add another loop that hides the unused axes.

    for j in range(len(locations), ncols*nrows):
        axs[j].axis("off")
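`axis("off")` only hides an unused axes; as an alternative, `Figure.delaxes` removes it from the figure entirely. The number of rows can also be derived from the number of locations with `math.ceil`, which covers both odd and even totals. A minimal sketch, assuming 22 locations and 6 columns as in the complete example of this answer (`n_used` is a hypothetical name), using the non-interactive Agg backend so it runs without a display:

```python
import math

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

n_used = 22  # assumed number of locations, one heatmap each
ncols = 6
nrows = math.ceil(n_used / ncols)  # smallest number of rows that fits all plots: 4

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharey=True)
axs = axs.flatten()

# ... draw the heatmaps into axs[:n_used] here ...

# Remove the unused trailing axes from the figure instead of just hiding them
for ax in axs[n_used:]:
    fig.delaxes(ax)

print(len(fig.axes))  # 22
```

Hiding keeps the grid slots reserved (useful if `tight_layout` should treat all cells equally); deleting makes the figure contain only the axes that carry data.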
    

Here is a complete example (the code that generates a random DataFrame is borrowed from @Robbie's answer):

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

days = ['Mon','Tue','Wed','Thurs','Fri','Sat','Sun']
names = ["Parkhaus {:02}".format(i+1) for i in range(22)]

nItems = 1000

df = pd.DataFrame()
df['name'] = [names[i] for i in np.random.randint(0,len(names),nItems)]
df['openLots'] = np.random.randint(0,100,nItems)
df['occupationRatio'] = np.random.rand(nItems)
df['DoW'] = [days[i] for i in np.random.randint(0,7,nItems)]
df['Hour'] = np.random.randint(0,12,nItems)

df2 = df[['name','openLots','occupationRatio','DoW','Hour']]
nrows, ncols = 4, 6
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15,9), sharey=True)
axs = axs.flatten()
locations = df2['name'].sort_values().unique()


def occupation_heatmap (name, ax, cbar=False, ylabel=False):
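    dfn = df2[df2['name'] == name]
    # Average only occupationRatio; averaging the string 'name' column fails in pandas >= 2.0
    dfn = dfn.groupby(['DoW', 'Hour'])['occupationRatio'].mean().unstack()
    # Reindex with the same day labels as the data ('Thurs'), otherwise Thursday stays empty
    dfn = dfn.reindex(days)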
    sns.heatmap(data=dfn, cmap="coolwarm", vmin=0, vmax=1.0, ax=ax, cbar=cbar)
    ax.set_title(name)
    plt.setp(ax.get_yticklabels(), rotation=0)
    if not ylabel: ax.set_ylabel("")


for i, n in enumerate(locations): 
    occupation_heatmap (n, axs[i], cbar=i%ncols==ncols-1, ylabel=i%ncols==0)
for j in range(len(locations), ncols*nrows):
    axs[j].axis("off")

plt.tight_layout()
plt.show()
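In the loop above, the colorbar goes to positions where `i % ncols == ncols - 1`. With 22 locations in a 4×6 grid the last row is only partly filled (indices 18 to 21), so it gets no colorbar. A small sketch of the position logic, using a hypothetical `panel_flags` helper that also gives the last used plot a colorbar; it only computes booleans, so its result can be passed straight into `occupation_heatmap`:

```python
def panel_flags(i, n, ncols):
    """Hypothetical helper: return (cbar, ylabel) for subplot index i
    out of n used subplots in a grid with ncols columns.

    cbar:   True for the rightmost used plot in each row, including a
            partially filled last row.
    ylabel: True for the leftmost plot in each row.
    """
    rightmost_in_row = i % ncols == ncols - 1
    last_used = i == n - 1
    return rightmost_in_row or last_used, i % ncols == 0

n, ncols = 22, 6
cbar_positions = [i for i in range(n) if panel_flags(i, n, ncols)[0]]
ylabel_positions = [i for i in range(n) if panel_flags(i, n, ncols)[1]]
print(cbar_positions)    # [5, 11, 17, 21]
print(ylabel_positions)  # [0, 6, 12, 18]
```

In the plotting loop this would become `cbar, ylabel = panel_flags(i, len(locations), ncols)` followed by `occupation_heatmap(n, axs[i], cbar=cbar, ylabel=ylabel)`. Note that the panels with a colorbar end up slightly narrower, because `sns.heatmap` takes the colorbar's space from the axes it draws into.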


Upvotes: 4

Robbie

Reputation: 4882

You can be more flexible and create one axes for each name that is actually present, something like this:

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import string

days = ['Mon','Tue','Wed','Thurs','Fri','Sat','Sun']
names = [string.ascii_lowercase[i] for i in range(22)]  # string.lowercase was Python 2 only

nItems = 1000

df = pd.DataFrame()
df['name'] = [names[i] for i in np.random.randint(0,len(names),nItems)]
df['openLots'] = np.random.randint(0,100,nItems)
df['occupationRatio'] = np.random.randint(0,100,nItems)
df['DoW'] = [days[i] for i in np.random.randint(0,7,nItems)]
df['Hour'] = np.random.randint(0,12,nItems)




fig = plt.figure(figsize=(12,12))
for index, name in enumerate(names):
    ax = fig.add_subplot(4,6,index+1)
    dfn = df.loc[df.name==name]
    # Average only occupationRatio; averaging the string 'name' column fails in pandas >= 2.0
    dfn = dfn.groupby(['DoW','Hour'])['occupationRatio'].mean().unstack()
    dfn = dfn.reindex(days)

    # Now we can operate on each plot axis individually
    if index % 6 != 5:  # not the last column of a row
        # Don't draw a colorbar
        sns.heatmap(data = dfn, cmap='coolwarm', ax=ax, cbar=False)
    else:
        sns.heatmap(data = dfn, cmap='coolwarm', ax=ax)

    if index%6!=0:
        # Remove the y-axis label
        ax.set_ylabel('')
        ax.set_yticks(())

    ax.set_title(name)

fig.tight_layout()
plt.show()

You could also adjust the x-axes in the same way (for example, remove the labels and ticks everywhere except the bottom row).
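When the last row is only partly filled, the bottom-most used plot in a column may sit one row higher, so "bottom row" needs a small check. A minimal sketch with a hypothetical `is_bottom` helper, demonstrated on plain matplotlib axes with the Agg backend (no seaborn needed to show the label logic); the 22 plots in a 4×6 grid are an assumption taken from the examples above:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt

def is_bottom(i, n, ncols):
    """Hypothetical helper: True if no used subplot sits directly below subplot i."""
    return i + ncols >= n

nrows, ncols, n = 4, 6, 22
fig, axs = plt.subplots(nrows=nrows, ncols=ncols)
axs = axs.flatten()

for i, ax in enumerate(axs[:n]):
    ax.set_xlabel("Hour")  # stands in for the label seaborn would set
    if not is_bottom(i, n, ncols):
        # Hide the x label and x tick labels, keep the plot itself
        ax.set_xlabel("")
        ax.tick_params(axis="x", labelbottom=False)

# Hide the unused trailing axes
for ax in axs[n:]:
    ax.axis("off")

bottom = [i for i in range(n) if is_bottom(i, n, ncols)]
print(bottom)  # [16, 17, 18, 19, 20, 21]
```

Indices 16 and 17 keep their x labels because the cells below them in the last row are empty.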

Upvotes: 2
