HW_Tang

Reputation: 107

One colorbar to indicate data range for multiple subplots using matplotlib?

I have seen many similar questions like this one. However, with the usual approach the colorbar only indicates the data range of the last subplot, as the following code shows:

import matplotlib.pyplot as plt
import numpy as np

np.random.seed(19680801)


fig, axs = plt.subplots(2, 1)
cmaps = ['RdBu_r', 'viridis']
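for row in range(2):
    ax = axs[row]
    if row == 0:
        # Top subplot: values in (-100, 0]
        pcm = ax.pcolormesh(np.random.random((20, 20)) * (-100),
                            cmap=cmaps[0])
    elif row == 1:
        # Bottom subplot: values in [0, 100)
        pcm = ax.pcolormesh(np.random.random((20, 20)) * 100,
                            cmap=cmaps[0])
# After the loop, pcm refers only to the bottom mesh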
fig.colorbar(pcm, ax=axs)
plt.show()

Output (image not reproduced)

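The colorbar only indicates the data range of the second subplot. The data in the first subplot is actually negative, but that range does not appear on the colorbar. Creating one colorbar per subplot does show each range, but then there is no single shared colorbar: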

import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(2, 1)
cmaps = ['RdBu_r', 'viridis']
for row in range(2):
    ax = axs[row]
    if row == 0:
        pcm = ax.pcolormesh(np.random.random((20, 20)) * (-100),
                            cmap=cmaps[0])
    elif row == 1:
        pcm = ax.pcolormesh(np.random.random((20, 20)) * 100,
                            cmap=cmaps[0])
    # One colorbar per subplot, each showing only that subplot's range
    fig.colorbar(pcm, ax=ax)
plt.show()

Output (image not reproduced)

So how can I make one colorbar, shared by multiple subplots, that indicates the overall data range?

The problem may be caused by fig.colorbar(pcm, ax=axs), where pcm refers to the mesh in the second subplot, but I am not sure how to solve it.
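One way to check this diagnosis is to inspect each mesh's color limits. Each pcolormesh call autoscales its own limits to its own data, and the colorbar takes its scale from the single mappable passed to it; the ax argument only decides which axes give up space for the colorbar. The minimal sketch below assumes nothing beyond the code above; the names pcm_neg and pcm_pos are illustrative, and the Agg backend is selected only so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(19680801)
fig, axs = plt.subplots(2, 1)
pcm_neg = axs[0].pcolormesh(np.random.random((20, 20)) * (-100), cmap='RdBu_r')
pcm_pos = axs[1].pcolormesh(np.random.random((20, 20)) * 100, cmap='RdBu_r')

# Each mesh autoscaled its color limits to its own data only.
print(pcm_neg.get_clim())  # both limits are negative
print(pcm_pos.get_clim())  # both limits are positive

# ax=axs only controls where room is made for the colorbar;
# its scale comes from the mappable given as the first argument.
cb = fig.colorbar(pcm_pos, ax=axs)
print(cb.mappable is pcm_pos)  # True
```

So passing all the axes is not enough on its own: the meshes themselves must share the same limits.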

Upvotes: 3

Views: 4801

Answers (1)

Jody Klymak

Reputation: 5913

Set the color limits (vmin and vmax) to be the same for every subplot:

import matplotlib.pyplot as plt 
import numpy as np 

fig, axs = plt.subplots(2, 1)
cmaps = ['RdBu_r', 'viridis']
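for row in range(2):
    ax = axs[row]
    mult = -100 if row == 0 else 100
    # Fixed vmin/vmax give both meshes the same color scale, so a colorbar
    # drawn from either mesh covers the data in both subplots.
    pcm = ax.pcolormesh(np.random.random((20, 20)) * mult,
                        cmap=cmaps[0], vmin=-150, vmax=150)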
fig.colorbar(pcm, ax=axs)
plt.show()
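The limits above are hard-coded at ±150. If all the data is available before plotting, a minimal sketch of the same technique can derive the shared limits from the data instead; the names datasets and meshes are illustrative, and the Agg backend is used only so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(19680801)
# Generate all the data first so the shared limits can be computed from it.
datasets = [np.random.random((20, 20)) * -100,
            np.random.random((20, 20)) * 100]
vmin = min(d.min() for d in datasets)
vmax = max(d.max() for d in datasets)

fig, axs = plt.subplots(2, 1)
meshes = []
for ax, data in zip(axs, datasets):
    # Same vmin/vmax for every mesh, taken from the overall data range.
    meshes.append(ax.pcolormesh(data, cmap='RdBu_r', vmin=vmin, vmax=vmax))

cb = fig.colorbar(meshes[-1], ax=axs)
print(meshes[0].get_clim(), meshes[1].get_clim())  # identical limits
```

For a diverging colormap such as RdBu_r you may prefer symmetric limits, e.g. lim = max(np.abs(d).max() for d in datasets) with vmin=-lim and vmax=lim, so that zero sits at the center of the colormap.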

Or, equivalently, you can pass the same Normalize object to every subplot:

import matplotlib.pyplot as plt 
import numpy as np 


fig, axs = plt.subplots(2, 1)
cmaps = ['RdBu_r', 'viridis']
norm = plt.Normalize(vmin=-150, vmax=150)
for row in range(2):
    ax = axs[row]
    mult = -100 if row == 0 else 100
    # Every mesh uses the same Normalize instance, hence the same limits.
    pcm = ax.pcolormesh(np.random.random((20, 20)) * mult,
                        cmap=cmaps[0], norm=norm)
fig.colorbar(pcm, ax=axs)
plt.show()
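One practical consequence of sharing a Normalize: every mesh holds a reference to the same instance, so changing its limits once rescales all subplots together. The minimal sketch below illustrates this; the new limits of ±100 and the name meshes are illustrative choices, and the Agg backend is used only so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize

np.random.seed(19680801)
fig, axs = plt.subplots(2, 1)
norm = Normalize(vmin=-150, vmax=150)  # one instance shared by every mesh
meshes = [
    ax.pcolormesh(np.random.random((20, 20)) * mult, cmap='RdBu_r', norm=norm)
    for ax, mult in zip(axs, (-100, 100))
]
fig.colorbar(meshes[-1], ax=axs)

# Both meshes reference the same Normalize, so updating it once
# changes the color limits of every subplot.
norm.vmin, norm.vmax = -100, 100
print([m.get_clim() for m in meshes])  # the same new limits for both meshes
```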

Upvotes: 2
