steeles

Reputation: 169

pandas, matplotlib: a way to assign same colors, line styles for same column labels across subplots?

I have a few data frames describing objects that share some variables (i.e., column names) varying over time, and I'm plotting them in subplots.

>>> df1.head()


         FR  stim_current  self_excitation    FF_inh  SFA
1  0.000000           0.0         0.000000 -0.075483   -0
2  0.000000           0.0         0.000000 -0.000000   -0
3 -0.000012           0.0         0.000000 -0.001761   -0
4 -0.000033           0.0        -0.000009 -0.003487    0
5 -0.000064           0.0        -0.000027 -0.005178    0

>>> df2.head()

         FR    FB_inh  stim_current  self_excitation
1  0.000000 -0.001569             1         0.000000
2  0.017609 -0.000000             1         0.000000
3  0.034867 -0.000200             1         0.010037
4  0.051780 -0.000577             1         0.019874
5  0.068355 -0.001109             1         0.029515

Is there a way to assign a line style by column name, so that, for instance, FR, stim_current, and self_excitation have the same colors in every subplot? Say I want FR to be blue and bold, stim_current to be black, and self_excitation to be green. I'd also like whatever differs between the data frames to show up in a different color on each subplot. Ideally I could also reorder the columns of each data frame so that the columns common to all of them appear at the top of the legend and the ones that vary appear at the bottom.
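For the legend-ordering part, one option is to reorder each frame's columns before plotting, because matplotlib lists legend entries in the order the lines were added. A minimal sketch, assuming pandas; the frames below only carry the column layouts from the question, and the helper name `shared_first` is hypothetical:

```python
import pandas as pd

# Column layouts from the question; load the real data here instead
df1 = pd.DataFrame(columns=["FR", "stim_current", "self_excitation", "FF_inh", "SFA"])
df2 = pd.DataFrame(columns=["FR", "FB_inh", "stim_current", "self_excitation"])


def shared_first(frames):
    """Hypothetical helper: put the columns present in every frame first."""
    # Shared columns keep the order they have in the first frame
    shared = [c for c in frames[0].columns if all(c in f.columns for f in frames)]
    # Each frame's remaining columns follow, in their original order
    return [f[shared + [c for c in f.columns if c not in shared]] for f in frames]


df1, df2 = shared_first([df1, df2])
print(list(df1.columns))  # ['FR', 'stim_current', 'self_excitation', 'FF_inh', 'SFA']
print(list(df2.columns))  # ['FR', 'stim_current', 'self_excitation', 'FB_inh']
```

Plotting the columns in this order then puts the shared variables at the top of each legend.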

Upvotes: 3

Views: 918

Answers (1)

rth

Reputation: 11201

It is possible to keep colors and line styles consistent across subplots by keying the plot settings on the column name, as in the following approach:

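import matplotlib.pyplot as plt

# load your pandas DataFrames df1 and df2 here

# two stacked subplots, one per variable
fig, ax = plt.subplots(2, 1, sharex=True)
# plot settings keyed by column name, reused in every subplot
pars = {'FR': {'color': 'r'},
        'stim_current': {'color': 'k'}}
# the line style tells the data frames apart: df1 dashed, df2 solid
ls_style = ['dashed', 'solid']
for ax_idx, name in enumerate(['FR', 'stim_current']):
    for label, df, ls in zip(['df1', 'df2'], [df1, df2], ls_style):
        ax[ax_idx].plot(df.index, df[name], ls=ls, label=label, **pars[name])
    ax[ax_idx].set_title(name)
    ax[ax_idx].legend()
plt.show()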
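The snippet above gives each variable its own subplot and uses the line style to tell the data frames apart. If each data frame should get its own subplot instead, the same idea still applies: look the style up by column name. A minimal sketch, assuming the frames have the columns shown in the question (rebuilt here from their first rows, with SFA taken as 0); the FR, stim_current and self_excitation styles follow the question, while the `extra_colors` palette for the other columns is an arbitrary choice:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove for an interactive window
import matplotlib.pyplot as plt
import pandas as pd

# First five rows of df1 and df2 from the question (SFA assumed to be 0)
idx = range(1, 6)
df1 = pd.DataFrame({
    "FR": [0.0, 0.0, -0.000012, -0.000033, -0.000064],
    "stim_current": [0.0] * 5,
    "self_excitation": [0.0, 0.0, 0.0, -0.000009, -0.000027],
    "FF_inh": [-0.075483, -0.0, -0.001761, -0.003487, -0.005178],
    "SFA": [0.0] * 5,
}, index=idx)
df2 = pd.DataFrame({
    "FR": [0.0, 0.017609, 0.034867, 0.051780, 0.068355],
    "FB_inh": [-0.001569, -0.0, -0.000200, -0.000577, -0.001109],
    "stim_current": [1] * 5,
    "self_excitation": [0.0, 0.0, 0.010037, 0.019874, 0.029515],
}, index=idx)

# Fixed styles for the columns named in the question
fixed = {
    "FR": {"color": "blue", "linewidth": 3},
    "stim_current": {"color": "black"},
    "self_excitation": {"color": "green"},
}
# Arbitrary palette for every other column; each new name takes the next color
extra_colors = ["tab:red", "tab:purple", "tab:orange", "tab:brown"]
extra = {}  # column name -> style, filled the first time a name is seen

frames = [df1, df2]
fig, axes = plt.subplots(len(frames), 1, sharex=True)
for ax, df in zip(axes, frames):
    # Fixed columns first so they head every legend; the rest keep their order
    cols = [c for c in fixed if c in df.columns]
    cols += [c for c in df.columns if c not in fixed]
    for col in cols:
        if col not in fixed and col not in extra:
            extra[col] = {"color": extra_colors[len(extra) % len(extra_colors)]}
        style = fixed[col] if col in fixed else extra[col]
        ax.plot(df.index, df[col], label=col, **style)
    ax.legend(loc="upper right")
# call fig.savefig(...) or plt.show() to display the figure
```

Because `extra` is shared across subplots, a varying column keeps one color if it ever appears in more than one frame, while distinct varying columns never reuse a fixed color.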

Upvotes: 3
