rbaleksandar

Reputation: 9681

How to share a color palette between multiple subplots?

I have the following figure:

enter image description here

The figure is produced by the following code snippet:

fig = plt.figure(constrained_layout=True)
grid = fig.add_gridspec(2, 2)

ax_samples_losses = fig.add_subplot(grid[0, 0:])
ax_samples_losses.set_title('Avg. loss per train sample (epoch 0 excluded)')
for sample_idx, sample_avg_train_loss_history in enumerate(samples_avg_train_loss_history):
    ax_samples_losses.plot(sample_avg_train_loss_history, label='Sample ' + str(sample_idx))
ax_samples_losses.set_xlabel('Epoch')
ax_samples_losses.set_ylabel('Sample avg. loss')
ax_samples_losses.set_xticks(range(1, epochs))
ax_samples_losses.tick_params(axis='x', rotation=90)
ax_samples_losses.yaxis.set_ticks(np.arange(0, np.max(samples_avg_train_loss_history), 0.25))
ax_samples_losses.tick_params(axis='both', which='major', labelsize=6)
plt.legend(bbox_to_anchor=(1, 1), prop={'size': 6}) #loc="upper left"
# fig.legend(...)

ax_patches_per_sample = fig.add_subplot(grid[1, 0])
#for sample_idx, sample_patches_count in enumerate(samples_train_patches_count):
#    ax_patches_per_sample.bar(sample_patches_count, label='Sample ' + str(sample_idx))
ax_patches_per_sample.bar(range(0, len(samples_train_patches_count)), samples_train_patches_count, align='center')
ax_patches_per_sample.set_title('Patches per sample')
ax_patches_per_sample.set_xlabel('Sample')
ax_patches_per_sample.set_ylabel('Patch count')
ax_patches_per_sample.set_xticks(range(0, len(samples_train_patches_count)))
ax_patches_per_sample.yaxis.set_ticks(np.arange(0, np.max(samples_train_patches_count), 20))
ax_patches_per_sample.tick_params(axis='both', which='major', labelsize=6)

where I would like both subplots to share the same legend and the same colors. I believe I need to do both.

The shared legend can be built with get_legend_handles_labels(). However, I do not know how to share colors. Both subplots describe different properties of the same thing: the samples. In short, I would like the "Patches per sample" subplot to use the same colors as the "Avg. loss per train sample (epoch 0 excluded)" subplot.

Upvotes: 0

Views: 1363

Answers (1)

Davide_sd

Reputation: 13135

The first plot uses matplotlib's default color cycle, whose colors are those of the discrete tab10 colormap. We can create a cycler over this colormap and set the color of each bar one by one:

import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.gridspec import GridSpec
import numpy as np
from itertools import cycle

# create a cycler to continuously loop over a discrete colormap
cycler = cycle(cm.tab10.colors)

N = 10
x = np.arange(N).astype(int)
y = np.random.uniform(5, 15, N)

f = plt.figure()
gs = GridSpec(2, 4)
ax0 = f.add_subplot(gs[0, :-1])
ax1 = f.add_subplot(gs[1, :-1])
ax2 = f.add_subplot(gs[:, -1])

for i in x:
    ax0.plot(x, np.exp(-x / (i + 1)), label="Sample %s" % (i + 1))
h, l = ax0.get_legend_handles_labels()

ax1.bar(x, y)
for p in ax1.patches:
    p.set_facecolor(next(cycler))

ax2.axis(False)
ax2.legend(h, l)
plt.tight_layout()

enter image description here
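Cycling over tab10 only matches the lines as long as they were drawn with the default color cycle. A minimal sketch of an alternative, assuming there is exactly one bar per line and both are in the same order, is to read each color back from the line handles with get_color(). The data and the names `lines` and `bars` are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

N = 5                            # hypothetical number of samples
x = np.arange(N)
y = np.random.uniform(5, 15, N)  # made-up bar heights

fig, (ax0, ax1) = plt.subplots(2, 1)

# one line per sample; matplotlib picks each color from its property cycle
lines = [
    ax0.plot(x, np.exp(-x / (i + 1)), label="Sample %s" % (i + 1))[0]
    for i in range(N)
]

# draw the bars, then copy each line's color onto the matching bar
bars = ax1.bar(x, y)
for line, bar in zip(lines, bars):
    bar.set_facecolor(line.get_color())

# a single figure-level legend built from the line handles
h, l = ax0.get_legend_handles_labels()
fig.legend(h, l, loc="upper right", prop={"size": 6})
```

Because the colors come from the artists themselves, this keeps working if the lines are later given explicit colors or a different style sheet.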

EDIT to accommodate a comment: tab10 has only ten colors, so they repeat when there are more samples. To avoid repetitions, sample the colors from a continuous colormap instead. Matplotlib offers many colormaps; alternatively, you can create your own.

import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.gridspec import GridSpec
import numpy as np
from itertools import cycle

N = 50
# sample N evenly spaced colors from a continuous colormap
colors = cm.viridis(np.linspace(0, 1, N))

x = np.arange(N).astype(int)
y = np.random.uniform(5, 15, N)

f = plt.figure()
gs = GridSpec(2, 4)
ax0 = f.add_subplot(gs[0, :-1])
ax1 = f.add_subplot(gs[1, :-1])
ax2 = f.add_subplot(gs[:, -1])

ax1.bar(x, y)

for i in x:
    c = colors[i]
    ax0.plot(x, np.exp(-x / (i + 1)), color=c, label="Sample %s" % (i + 1))
    ax1.patches[i].set_facecolor(c)
h, l = ax0.get_legend_handles_labels()

ax2.axis(False)
ax2.legend(h, l)
plt.tight_layout()

enter image description here
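For the "create your own" option mentioned in the edit, one possibility is LinearSegmentedColormap.from_list, which interpolates between a few anchor colors. This is only a sketch: the colormap name `my_cmap` and the three anchor colors are arbitrary choices for illustration, not something taken from the question.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

N = 50  # number of samples, as in the snippet above

# hypothetical custom colormap blending three named colors
my_cmap = LinearSegmentedColormap.from_list("my_cmap", ["navy", "gold", "crimson"])
colors = my_cmap(np.linspace(0, 1, N))  # array of N RGBA rows

x = np.arange(N)
y = np.random.uniform(5, 15, N)  # made-up bar heights

fig, (ax0, ax1) = plt.subplots(2, 1)
bars = ax1.bar(x, y)

# give line i and bar i the same color
for i in x:
    ax0.plot(x, np.exp(-x / (i + 1)), color=colors[i], label="Sample %s" % (i + 1))
    bars[i].set_facecolor(colors[i])
```

Any colormap object can be swapped in for `my_cmap` here, since calling a colormap with values in [0, 1] returns RGBA colors.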

Upvotes: 2
