JC_CL

Reputation: 2638

How to change the colors of multiple subplots at once?

I am looping through a bunch of CSV files containing various measurements.

Each file might be from one of 4 different data sources.

For each file, I merge the data into monthly datasets, which I then plot in a 3x4 grid. After the plot has been saved, the loop moves on to the next file and does the same.

I have this part figured out. However, I would like to add a visual cue to the plots indicating which data source they come from. As far as I understand it (and have tried it),

plt.subplot(4,3,1)
plt.hist(Jan_Data,facecolor='Red')
plt.ylabel('value count')
plt.title('January')

does work. However, this way I would have to add facecolor='Red' by hand to each of the 12 subplots. Looping through the plots won't work in this situation, since I want the ylabel only on the leftmost plots and xlabels only on the bottom row.

Setting facecolor once at the beginning, in

fig = plt.figure(figsize=(20,15),facecolor='Red')

does not work either: it only changes the background color of the 20x15-inch figure, and that background is then ignored when I save the figure to a PNG, since it is only set for screen output.

So is there a simple setthecolorofallbars='Red'-style option for plt.hist(… or plt.savefig(… that I am missing, or should I just copy and paste it into all twelve months?

Upvotes: 3

Views: 2352

Answers (2)

cel

Reputation: 31399

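You can use mpl.rc("axes", prop_cycle=cycler(color=["red"])) to set the default color cycle for all your axes; hist then draws its bars in red unless you pass a color explicitly. In matplotlib versions before 1.5 the equivalent was mpl.rc("axes", color_cycle="red"), but the axes.color_cycle parameter has since been replaced by axes.prop_cycle and no longer exists in current versions.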

In this little toy example, I use a with mpl.rc_context(): block to limit the effect of mpl.rc to that block. This way you don't change the default parameters for the rest of your session.

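import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler  # cycler ships as a matplotlib dependency
np.random.seed(42)

# create some toy data: n*m arrays of 30 uniform random values
n, m = 2, 2
data = [np.random.rand(30) for _ in range(n * m)]

# and do the plotting; the rc change is undone when the with block ends
with mpl.rc_context():
    mpl.rc("axes", prop_cycle=cycler(color=["red"]))
    fig, axes = plt.subplots(n, m, figsize=(8, 8))
    for ax, d in zip(axes.flat, data):
        ax.hist(d)  # no color argument: the bars take "red" from the cycle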

[Image: the resulting plot]
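To apply this to the question's setup (several data sources, twelve months on a 4x3 grid), one option is to look up the color per file and pass it to rc_context, which also accepts a dict of settings directly. A minimal sketch, in which the source names, their colors, the random data and the output filename are all hypothetical stand-ins for the real CSV loop:

```python
import matplotlib as mpl
mpl.use("Agg")  # off-screen rendering; assumption: figures are only saved, not shown
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler

# Hypothetical mapping from data source to bar color; the keys stand in
# for however each CSV file's source is actually identified.
SOURCE_COLORS = {"source_a": "red", "source_b": "blue",
                 "source_c": "green", "source_d": "orange"}
MONTHS = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
          "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]


def plot_year(monthly_data, source):
    """Draw 12 monthly histograms on a 4x3 grid in the source's color."""
    cycle = cycler(color=[SOURCE_COLORS[source]])
    # Passing the dict to rc_context replaces the separate mpl.rc call;
    # the color cycle reverts to the default when the block ends.
    with mpl.rc_context({"axes.prop_cycle": cycle}):
        fig, axes = plt.subplots(4, 3, figsize=(20, 15))
        for ax, month, values in zip(axes.flat, MONTHS, monthly_data):
            ax.hist(values)  # color comes from the cycle, not an argument
            ax.set_title(month)
    return fig


# Stand-in for the loop over CSV files: random data instead of parsed files.
rng = np.random.default_rng(42)
fig = plot_year([rng.random(30) for _ in MONTHS], "source_b")
fig.savefig("source_b.png")  # hypothetical output filename
plt.close(fig)
```

Because the color lives in the rc settings rather than in each hist call, the per-subplot code (titles, and labels for the outer rows and columns) doesn't need to mention color at all.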

Upvotes: 3

plonser

Reputation: 3363

The problem with the x- and y-labels (when you use loops) can be solved by using plt.subplots, since it returns an array of axes that you can access individually.

import matplotlib.pyplot as plt
import numpy.random

# creating figure with 4 plots
fig,ax = plt.subplots(2,2)

# some data
data = numpy.random.randn(4,1000)

# some titles
title = ['Jan','Feb','Mar','Apr']

xlabel = ['xlabel1','xlabel2']
ylabel = ['ylabel1','ylabel2']

for i in range(ax.size):
    # integer division: row = i // 2, column = i % 2 (i/2 is a float in Python 3)
    a = ax[i // 2, i % 2]
    a.hist(data[i], facecolor='r', bins=50)
    a.set_title(title[i])

# write the ylabels on all axes in the left-hand column
for j in range(ax.shape[0]):
    ax[j,0].set_ylabel(ylabel[j])

# write the xlabels on all axes in the bottom row
for j in range(ax.shape[1]):
    ax[-1, j].set_xlabel(xlabel[j])


fig.tight_layout()
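The same idea also works in a single loop. A minimal sketch, assuming the 4x3 grid from the question (the "value" x-label and the random data are placeholders): divmod recovers each axes' row and column from its position in axes.flat, so the labels can be set conditionally inside the loop.

```python
import matplotlib as mpl
mpl.use("Agg")  # off-screen rendering so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

nrows, ncols = 4, 3  # 4x3 grid as in the question
rng = np.random.default_rng(1)
fig, axes = plt.subplots(nrows, ncols, figsize=(20, 15))

for i, ax in enumerate(axes.flat):
    row, col = divmod(i, ncols)  # axes.flat walks the grid row by row
    ax.hist(rng.standard_normal(1000), bins=50, facecolor="red")
    if col == 0:  # leftmost column only
        ax.set_ylabel("value count")
    if row == nrows - 1:  # bottom row only
        ax.set_xlabel("value")  # placeholder x-label text

fig.tight_layout()
```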

Any feature that differs between subplots (such as the titles) can be put into a list and applied to the appropriate axes.
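Another option, assuming the subplots can share their x and y ranges, is Axes.label_outer(), which removes x-labels from every row except the bottom one and y-labels from every column except the leftmost. A minimal sketch with placeholder data and label text:

```python
import matplotlib as mpl
mpl.use("Agg")  # off-screen rendering so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(2)
# Shared axes keep the remaining tick labels valid for every subplot.
fig, axes = plt.subplots(4, 3, sharex=True, sharey=True, figsize=(20, 15))
for ax in axes.flat:
    ax.hist(rng.standard_normal(1000), bins=50, facecolor="red")
    ax.set_xlabel("value")        # placeholder label text
    ax.set_ylabel("value count")
    ax.label_outer()  # keep labels only on the bottom row and left column
fig.tight_layout()
```

label_outer also hides the tick labels of inner subplots, which is why the sketch shares the axes; with independent scales per month, the explicit loops above are the safer choice.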

Upvotes: 2
