Peteris

Reputation: 3756

How to plot a pandas DataFrame with multiple axes each rendering multiple columns?

I am seeking a function that would work as follows:

from typing import List

import pandas as pd

def plot_df(df: pd.DataFrame, x_column: str, columns: List[List[str]]):
    """Plot `df` using `x_column` on the x-axis and `len(columns)` different
    y-axes, where the axis numbered `i` is calibrated to render the columns in `columns[i]`.

    Important: only one legend exists for the plot.
    Important: each column has a distinct color.
        If you wonder what colors the axes should have, they can take one of the line
        colors and simply have a label attached (e.g., one axis for price, another for
        returns, another for growth).
    """

For example, for a DataFrame with the columns time, price1, price2, returns and growth, you could call it like so:

plot_df(df, 'time', [['price1', 'price2'], ['returns'], ['growth']])

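This would result in a chart with three y-axes (one shared by price1 and price2, one for returns and one for growth), four distinctly colored lines and a single legend.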

I've looked at a couple of solutions which don't work for this.

Possible solution #1:

https://matplotlib.org/stable/gallery/ticks_and_spines/multiple_yaxis_with_spines.html

In this example, each axis can only accommodate one column, so it doesn't fit my use case. In particular, in the following code each axis has exactly one series:

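import matplotlib.pyplot as plt
fig, ax = plt.subplots()
twin1, twin2 = ax.twinx(), ax.twinx()  # minimal setup for this excerpt
p1, = ax.plot([0, 1, 2], [0, 1, 2], "b-", label="Density")
p2, = twin1.plot([0, 1, 2], [0, 3, 2], "r-", label="Temperature")
p3, = twin2.plot([0, 1, 2], [50, 30, 15], "g-", label="Velocity")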

If you add another line to one of these axes, the same color ends up being used twice.


Moreover, this version does not use the built-in DataFrame.plot() method.
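The duplication comes from the fact that every Axes, including one created with twinx(), keeps its own color cycle, so the first line on a twin axis gets the same default color (C0) as the first line on the main axis. A minimal sketch, assuming the default property cycle, that shows the problem and one workaround: a single running color index shared by all axes, using Matplotlib's "CN" color names.

```python
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

x = [0, 1, 2]

# Problem: no explicit colors, so each Axes draws from its own color cycle
fig, ax = plt.subplots()
twin = ax.twinx()
a, = ax.plot(x, [0, 1, 2], label="a")    # first color of ax's cycle (C0)
b, = ax.plot(x, [2, 1, 0], label="b")    # second color of ax's cycle (C1)
c, = twin.plot(x, [0, 3, 2], label="c")  # twin's cycle starts again at C0
naive_colors = [to_hex(line.get_color()) for line in (a, b, c)]

# Workaround: one running color index shared by all axes
fig2, ax2 = plt.subplots()
twin2 = ax2.twinx()
fixed_lines = []
for k, (axis, ys) in enumerate([(ax2, [0, 1, 2]), (ax2, [2, 1, 0]), (twin2, [0, 3, 2])]):
    line, = axis.plot(x, ys, color=f"C{k}")  # "Ck" = k-th color of the default cycle
    fixed_lines.append(line)
fixed_colors = [to_hex(line.get_color()) for line in fixed_lines]

print(naive_colors)
print(fixed_colors)
```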

Possible solution #2:

PANDAS plot multiple Y axes

In this example, too, each axis can only accommodate a single column of the DataFrame.

Possible solution #3:

Adapt solution #2 by changing df.A to df[['A', 'B']]. Unfortunately, this doesn't work: both columns end up drawn in the same color as their axis, and multiple legends pop up.
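One way around both problems, sketched on a toy DataFrame whose columns A, B and C are made up for illustration: pass color as a list with one entry per column, suppress the per-call legends with legend=False, and build one legend from the handles of every axis via get_legend_handles_labels().

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Toy data (hypothetical columns): A and B share one axis, C gets its own
df = pd.DataFrame({"A": np.arange(5), "B": np.arange(5) * 2, "C": np.arange(5) * 100.0})

fig, ax = plt.subplots()
twin = ax.twinx()

# One color per column, and no legend from the individual plot() calls
df[["A", "B"]].plot(ax=ax, color=["C0", "C1"], legend=False)
df[["C"]].plot(ax=twin, color=["C2"], legend=False)

# Collect the line handles/labels of both axes and draw a single legend
handles, labels = [], []
for axis in (ax, twin):
    h, l = axis.get_legend_handles_labels()
    handles += h
    labels += l
ax.legend(handles, labels)
```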

So I'm asking the pandas/matplotlib experts: can you figure out how to overcome this?

Upvotes: 2

Views: 3763

Answers (3)

DanDevost

Reputation: 61

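You can chain the plots: create each extra y-axis with twinx() on the first plot's axis and pass it to the next df.plot() call through the ax argument.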

import pandas as pd
import numpy as np

Create the data and put it in a df.

x=np.arange(0,2*np.pi,0.01)
b=np.sin(x)
c=np.cos(x)*10
d=np.sin(x+np.pi/4)*100
e=np.sin(x+np.pi/3)*50
df = pd.DataFrame({'x':x,'y1':b,'y2':c,'y3':d,'y4':e})

Define the first plot and the subsequent axes:

ax1 = df.plot(x='x',y='y1',legend=None,color='black',figsize=(10,8))
ax2 = ax1.twinx()
ax2.tick_params(axis='y', labelcolor='r')

ax3 = ax1.twinx()
ax3.spines['right'].set_position(('axes',1.15))
ax3.tick_params(axis='y', labelcolor='g')

ax4=ax1.twinx()
ax4.spines['right'].set_position(('axes',1.30))
ax4.tick_params(axis='y', labelcolor='b')

You can add as many axes as you want this way.

Plot the remaining columns:

df.plot(x='x',y='y2',ax=ax2,color='r',legend=None)
df.plot(x='x',y='y3',ax=ax3,color='g',legend=None)
df.plot(x='x',y='y4',ax=ax4,color='b',legend=None)

Results:
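The chaining above puts one column on each axis. A sketch of how it might be generalized to the question's plot_df signature while still using DataFrame.plot(): each group of columns is drawn with y=[...] on its own twinx() axis, colors come from one running index, and a single legend is built at the end. The offset parameter and the spine spacing are assumptions, not part of the original answer.

```python
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_df(df: pd.DataFrame, x_column: str, columns: List[List[str]], offset: float = 0.15):
    """Sketch: one y-axis per group in `columns`, each drawn with DataFrame.plot()."""
    fig, ax1 = plt.subplots(figsize=(10, 6))
    axes = [ax1]
    for i in range(1, len(columns)):
        twin = ax1.twinx()
        # the first twin uses the default right spine, later ones are pushed outward
        twin.spines["right"].set_position(("axes", 1 + (i - 1) * offset))
        axes.append(twin)

    k = 0  # running color index shared by all axes
    for axis, group in zip(axes, columns):
        colors = [f"C{k + n}" for n in range(len(group))]
        df.plot(x=x_column, y=group, ax=axis, color=colors, legend=False)
        axis.set_ylabel(", ".join(group))
        k += len(group)

    # a single legend collecting the lines of every axis
    handles = [h for axis in axes for h in axis.get_legend_handles_labels()[0]]
    ax1.legend(handles=handles, loc="upper left")
    fig.tight_layout()
    return fig, axes


# Demo data built like the DataFrame in the answer above
x = np.arange(0, 2 * np.pi, 0.01)
demo = pd.DataFrame({"x": x, "y1": np.sin(x), "y2": np.cos(x) * 10,
                     "y3": np.sin(x + np.pi / 4) * 100, "y4": np.sin(x + np.pi / 3) * 50})
fig, axes = plot_df(demo, "x", [["y1", "y2"], ["y3"], ["y4"]])
```

Call plt.show() afterwards to display the figure. Note that "CN" color names wrap around after the 10 colors of the default cycle, so more than 10 columns would repeat colors.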

Upvotes: 2

Zephyr

Reputation: 12524

I assume you are working with a dataframe like this:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

df = pd.DataFrame({'time': pd.date_range(start = '2020-01-01', end = '2020-01-10', freq = 'D')})
np.random.seed(42)  # fixed seed, so the random values match the table below
df['price1'] = np.random.random(len(df))
df['price2'] = np.random.random(len(df))
df['returns'] = np.random.random(len(df))
df['growth'] = np.random.random(len(df))
print(df)
        time    price1    price2   returns    growth
0 2020-01-01  0.374540  0.020584  0.611853  0.607545
1 2020-01-02  0.950714  0.969910  0.139494  0.170524
2 2020-01-03  0.731994  0.832443  0.292145  0.065052
3 2020-01-04  0.598658  0.212339  0.366362  0.948886
4 2020-01-05  0.156019  0.181825  0.456070  0.965632
5 2020-01-06  0.155995  0.183405  0.785176  0.808397
6 2020-01-07  0.058084  0.304242  0.199674  0.304614
7 2020-01-08  0.866176  0.524756  0.514234  0.097672
8 2020-01-09  0.601115  0.431945  0.592415  0.684233
9 2020-01-10  0.708073  0.291229  0.046450  0.440152

Then a possible function could be:

def plot_df(df, x_column, columns):

    # cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9
    cmap = plt.get_cmap('tab10')

    fig, ax = plt.subplots()

    axes = [ax]
    handles = []

    # one extra y-axis per remaining column group, each spine shifted further right
    for i in range(len(columns) - 1):
        twin = ax.twinx()
        axes.append(twin)
        twin.spines.right.set_position(("axes", 1 + i/10))

    j = 0  # running color index, so every column gets its own color
    for i, col in enumerate(columns):
        for sub_col in col:
            p, = axes[i].plot(df[x_column], df[sub_col], label = sub_col, color = cmap(j)[:3])
            handles.append(p)
            j += 1
        axes[i].set_ylabel(', '.join(col))

    # a single legend collecting the lines from all axes
    ax.legend(handles = handles, frameon = True)

    plt.tight_layout()

    plt.show()

If you call the above function with:

plot_df(df, 'time', [['price1', 'price2'], ['returns'], ['growth']])

then you will get:


NOTES

The first element of the columns list (['price1', 'price2'] in this case) is always drawn on the left axis; the other elements are drawn on the right ones.
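One caveat about the distinct-color requirement: tab10 has only 10 entries, and calling the colormap with an index of 10 or more returns its last color, so an eleventh column would repeat a color. A small sketch of a hypothetical helper, distinct_colors, that falls back to evenly spaced samples of a continuous colormap (viridis here, an arbitrary choice) when more colors are needed:

```python
import matplotlib.pyplot as plt
import numpy as np


def distinct_colors(n):
    """Hypothetical helper: return n distinct RGBA tuples, one per plotted column."""
    if n <= 10:
        cmap = plt.get_cmap("tab10")  # qualitative map with 10 entries
        return [cmap(i) for i in range(n)]
    # more columns than tab10 has entries: sample a continuous colormap evenly
    return [tuple(rgba) for rgba in plt.get_cmap("viridis")(np.linspace(0, 1, n))]


# integer indices past the end of tab10 all map to its last color
tab10 = plt.get_cmap("tab10")
print(tab10(10) == tab10(9))
print(len(set(distinct_colors(15))))
```

Inside plot_df one could compute colors = distinct_colors(sum(len(col) for col in columns)) once and pass color = colors[j] instead of cmap(j)[:3].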

Upvotes: 1

Zephyr

Reputation: 12524

I assume you are working with a dataframe like this:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

df = pd.DataFrame({'time': pd.date_range(start = '2020-01-01', end = '2020-01-10', freq = 'D')})
np.random.seed(42)  # fixed seed, so the random values match the table below
df['price1'] = np.random.random(len(df))
df['price2'] = np.random.random(len(df))
df['returns'] = np.random.random(len(df))
df['growth'] = np.random.random(len(df))
print(df)
        time    price1    price2   returns    growth
0 2020-01-01  0.374540  0.020584  0.611853  0.607545
1 2020-01-02  0.950714  0.969910  0.139494  0.170524
2 2020-01-03  0.731994  0.832443  0.292145  0.065052
3 2020-01-04  0.598658  0.212339  0.366362  0.948886
4 2020-01-05  0.156019  0.181825  0.456070  0.965632
5 2020-01-06  0.155995  0.183405  0.785176  0.808397
6 2020-01-07  0.058084  0.304242  0.199674  0.304614
7 2020-01-08  0.866176  0.524756  0.514234  0.097672
8 2020-01-09  0.601115  0.431945  0.592415  0.684233
9 2020-01-10  0.708073  0.291229  0.046450  0.440152

Then a possible function could be:

def plot_df(df, x_column, columns):

    # cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9
    cmap = plt.get_cmap('tab10')
    line_styles = ["-", "--", "-.", ":"]

    fig, ax = plt.subplots()

    axes = [ax]
    handles = []

    for i in range(len(columns) - 1):
        twin = ax.twinx()
        axes.append(twin)
        twin.spines.right.set_position(("axes", 1 + i/10))

    # one color per axis; columns sharing an axis are told apart by line style
    for i, col in enumerate(columns):
        for j, sub_col in enumerate(col):
            p, = axes[i].plot(df[x_column], df[sub_col], label = sub_col, color = cmap(i)[:3],
                              linestyle = line_styles[j % len(line_styles)])
            handles.append(p)

    ax.legend(handles = handles, frameon = True)

    # color each axis (ticks, tick labels and spine) like its lines
    for i, axis in enumerate(axes):
        axis.tick_params(axis = 'y', colors = cmap(i)[:3])
        if i == 0:
            axis.spines['left'].set_color(cmap(i)[:3])
            axis.spines['right'].set_visible(False)
        else:
            axis.spines['left'].set_visible(False)
            axis.spines['right'].set_color(cmap(i)[:3])

    plt.tight_layout()

    plt.show()

If you call the above function with:

plot_df(df, 'time', [['price1', 'price2'], ['returns'], ['growth']])

then you will get:


NOTES

  1. Since price1 and price2 share the same y-axis, they are drawn in that axis's color, so I use different line styles to distinguish them.
  2. The first element of the columns list (['price1', 'price2'] in this case) is always drawn on the left axis; the other elements are drawn on the right ones.
  3. If you want to set axis limits and labels, you can pass them as additional parameters to plot_df.
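Following up on note 3, a minimal sketch of what such extra parameters could look like. style_axes, ylabels and ylims are hypothetical names, not part of the answer above; each list is assumed to be aligned with axes (and therefore with the groups in columns), and a None entry leaves that axis unchanged.

```python
import matplotlib.pyplot as plt


def style_axes(axes, ylabels=None, ylims=None):
    """Hypothetical helper: apply per-axis y labels and y limits.

    `ylabels` and `ylims` are lists aligned with `axes`; a None entry
    (or a None list) leaves the corresponding axis untouched.
    """
    for i, axis in enumerate(axes):
        if ylabels is not None and ylabels[i] is not None:
            axis.set_ylabel(ylabels[i])
        if ylims is not None and ylims[i] is not None:
            axis.set_ylim(*ylims[i])  # (bottom, top) pair


# Example on a main axis plus one twin
fig, ax = plt.subplots()
twin = ax.twinx()
style_axes([ax, twin], ylabels=["price", None], ylims=[(0, 1), (-5, 5)])
```

Inside plot_df it could be called as style_axes(axes, ylabels, ylims) just before plt.tight_layout(), after adding ylabels=None, ylims=None to the function's signature.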

Upvotes: 1
