Reputation: 3756
I am seeking a function that would work as follows:
from typing import List

import pandas as pd

def plot_df(df: pd.DataFrame, x_column: str, columns: List[List[str]]):
    """Plot DataFrame using `x_column` on the x-axis and `len(columns)` different
    y-axes, where the axis numbered `i` is calibrated to render the columns in `columns[i]`.

    Important: only 1 legend exists for the plot.
    Important: each column has a distinct color.
    If you wonder what colors the axes should have, they can assume one of the line colors
    and just have a label associated (e.g., one axis for price, another for returns,
    another for growth).
    """
As an example, for a DataFrame with the columns time, price1, price2, returns, growth
you could call it like so:
plot_df(df, 'time', [['price1', 'price2'], ['returns'], ['growth']])
This would result in a chart in which price1 and price2 share one y-axis, while returns and growth each get their own.
I've looked at a couple of solutions which don't work for this.
Possible solution #1:
https://matplotlib.org/stable/gallery/ticks_and_spines/multiple_yaxis_with_spines.html
In this example, each axis can only accommodate one column, so it does not meet the requirement. In particular, in the following code each axis has one series:
p1, = ax.plot([0, 1, 2], [0, 1, 2], "b-", label="Density")
p2, = twin1.plot([0, 1, 2], [0, 3, 2], "r-", label="Temperature")
p3, = twin2.plot([0, 1, 2], [50, 30, 15], "g-", label="Velocity")
If you add another plot to one of these axes, the same color ends up duplicated.
Moreover, this version does not use the built-in plot() function of data frames.
Possible solution #2:
In this example, too, each axis can only accommodate a single column from the data frame.
Possible solution #3:
Try to adapt solution #2 by changing df.A to df[['A', 'B']]. This does not work either, since the two columns then share the same axis color and multiple legends pop up.
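The legend and color behaviour described for solution #3 can be reproduced with a small sketch. The data and the column names A, B and C are made up for illustration, and the Agg backend is only selected so that the snippet runs without a display. It appears that each DataFrame.plot call draws its own legend and starts again from the first default color:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
import pandas as pd

# Hypothetical data; the column names A, B, C are illustrative only.
df = pd.DataFrame({"x": [0, 1, 2], "A": [0, 1, 2], "B": [2, 1, 0], "C": [50, 30, 15]})

fig, ax = plt.subplots()
twin = ax.twinx()

# Two separate DataFrame.plot calls, one per y-axis, with default settings.
df.plot(x="x", y=["A", "B"], ax=ax)
df.plot(x="x", y=["C"], ax=twin)

# Collect the line colors per axis and count how many axes carry a legend.
left_colors = [to_hex(line.get_color()) for line in ax.get_lines()]
right_colors = [to_hex(line.get_color()) for line in twin.get_lines()]
legend_count = sum(a.get_legend() is not None for a in (ax, twin))
print(left_colors, right_colors, legend_count)
```

Passing legend=False and an explicit color to every call avoids both effects, at the cost of having to build the single legend by hand.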
So - asking pandas/matplotlib experts if you can figure out how to overcome this!
Upvotes: 2
Views: 3763
Reputation: 61
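You can chain the axes from one df.plot call to the next: create the extra axes with twinx() and pass each one to df.plot through the ax argument.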
import pandas as pd
import numpy as np
Create the data and put it in a df.
x=np.arange(0,2*np.pi,0.01)
b=np.sin(x)
c=np.cos(x)*10
d=np.sin(x+np.pi/4)*100
e=np.sin(x+np.pi/3)*50
df = pd.DataFrame({'x':x,'y1':b,'y2':c,'y3':d,'y4':e})
Define first plot and subsequent axes
ax1 = df.plot(x='x',y='y1',legend=None,color='black',figsize=(10,8))
ax2 = ax1.twinx()
ax2.tick_params(axis='y', labelcolor='r')
ax3 = ax1.twinx()
ax3.spines['right'].set_position(('axes',1.15))
ax3.tick_params(axis='y', labelcolor='g')
ax4=ax1.twinx()
ax4.spines['right'].set_position(('axes',1.30))
ax4.tick_params(axis='y', labelcolor='b')
You can add as many as you want...
Plot the remainder.
df.plot(x='x',y='y2',ax=ax2,color='r',legend=None)
df.plot(x='x',y='y3',ax=ax3,color='g',legend=None)
df.plot(x='x',y='y4',ax=ax4,color='b',legend=None)
Upvotes: 2
Reputation: 12524
I assume you are working with a dataframe like this:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
np.random.seed(42)  # this seed reproduces the values printed below
df = pd.DataFrame({'time': pd.date_range(start = '2020-01-01', end = '2020-01-10', freq = 'D')})
df['price1'] = np.random.random(len(df))
df['price2'] = np.random.random(len(df))
df['returns'] = np.random.random(len(df))
df['growth'] = np.random.random(len(df))
        time    price1    price2   returns    growth
0 2020-01-01  0.374540  0.020584  0.611853  0.607545
1 2020-01-02  0.950714  0.969910  0.139494  0.170524
2 2020-01-03  0.731994  0.832443  0.292145  0.065052
3 2020-01-04  0.598658  0.212339  0.366362  0.948886
4 2020-01-05  0.156019  0.181825  0.456070  0.965632
5 2020-01-06  0.155995  0.183405  0.785176  0.808397
6 2020-01-07  0.058084  0.304242  0.199674  0.304614
7 2020-01-08  0.866176  0.524756  0.514234  0.097672
8 2020-01-09  0.601115  0.431945  0.592415  0.684233
9 2020-01-10  0.708073  0.291229  0.046450  0.440152
Then a possible function could be:
def plot_df(df, x_column, columns):
    # plt.get_cmap is used instead of cm.get_cmap, which was deprecated and later removed from matplotlib
    cmap = plt.get_cmap('tab10', 10)
    fig, ax = plt.subplots()

    # one y-axis per group: the first on the left, the others as twins on the right
    axes = [ax]
    handles = []
    for i in range(len(columns) - 1):
        twin = ax.twinx()
        axes.append(twin)
        # shift each additional right-hand spine outwards so they do not overlap
        twin.spines.right.set_position(("axes", 1 + i/10))

    j = 0  # running color index, so every column gets its own color
    for i, col in enumerate(columns):
        ylabel = []
        for sub_col in col:
            p, = axes[i].plot(df[x_column], df[sub_col], label = sub_col, color = cmap(j)[:3])
            ylabel.append(sub_col)
            handles.append(p)
            j += 1
        axes[i].set_ylabel(', '.join(ylabel))

    # a single legend built from the lines of all axes
    ax.legend(handles = handles, frameon = True)
    plt.tight_layout()
    plt.show()
If you call the above function with:
plot_df(df, 'time', [['price1', 'price2'], ['returns'], ['growth']])
then you will get:
NOTES
The first element of the columns list (['price1', 'price2'] in this case) is always drawn on the left axis, the other elements on the right ones.
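The question also asks for the built-in plot() of data frames. A possible variant along those lines is sketched below. The name plot_df_pandas is hypothetical, the random data is only a stand-in, the figure is returned instead of shown, and the Agg backend is selected only so that the snippet runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
import numpy as np
import pandas as pd


def plot_df_pandas(df, x_column, columns):
    """Hypothetical variant: one y-axis per group, drawn with DataFrame.plot."""
    palette = [to_hex(c) for c in plt.get_cmap("tab10").colors]  # ten distinct colors
    fig, ax = plt.subplots()
    axes = [ax] + [ax.twinx() for _ in columns[1:]]
    start = 0  # index of the next unused color
    for i, (axis, group) in enumerate(zip(axes, columns)):
        if i > 1:
            # push the second and later right-hand spines outwards
            axis.spines["right"].set_position(("axes", 1 + (i - 1) / 10))
        colors = [palette[(start + k) % len(palette)] for k in range(len(group))]
        df.plot(x=x_column, y=group, ax=axis, color=colors, legend=False)
        axis.set_ylabel(", ".join(group))
        start += len(group)
    # one legend for the lines of all axes, attached to the top-most axis
    handles = [line for axis in axes for line in axis.get_lines()]
    axes[-1].legend(handles=handles)
    fig.tight_layout()
    return fig, axes


# Stand-in data with the column names from the question.
df = pd.DataFrame({"time": pd.date_range("2020-01-01", periods=10, freq="D")})
rng = np.random.default_rng(0)
for name in ["price1", "price2", "returns", "growth"]:
    df[name] = rng.random(len(df))

fig, axes = plot_df_pandas(df, "time", [["price1", "price2"], ["returns"], ["growth"]])
```

With more than ten columns the colors repeat, because tab10 only has ten entries.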
Upvotes: 1
Reputation: 12524
I assume you are working with a dataframe like this:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
df = pd.DataFrame({'time': pd.date_range(start = '2020-01-01', end = '2020-01-10', freq = 'D')})
df['price1'] = np.random.random(len(df))
df['price2'] = np.random.random(len(df))
df['returns'] = np.random.random(len(df))
df['growth'] = np.random.random(len(df))
Then a possible function could be:
def plot_df(df, x_column, columns):
    # plt.get_cmap is used instead of cm.get_cmap, which was deprecated and later removed from matplotlib
    cmap = plt.get_cmap('tab10', 10)
    line_styles = ["-", "--", "-.", ":"]
    fig, ax = plt.subplots()

    # one y-axis per group: the first on the left, the others as twins on the right
    axes = [ax]
    handles = []
    for i in range(len(columns) - 1):
        twin = ax.twinx()
        axes.append(twin)
        twin.spines.right.set_position(("axes", 1 + i/10))

    # one color per axis; columns sharing an axis differ by line style
    for i, col in enumerate(columns):
        for j, sub_col in enumerate(col):
            # cycle through the styles so more than four columns per axis do not raise an IndexError
            p, = axes[i].plot(df[x_column], df[sub_col], label = sub_col, color = cmap(i)[:3], linestyle = line_styles[j % len(line_styles)])
            handles.append(p)

    ax.legend(handles = handles, frameon = True)

    # color the ticks and the visible spine of each axis like its lines
    for i, axis in enumerate(axes):
        axis.tick_params(axis = 'y', colors = cmap(i)[:3])
        if i == 0:
            axis.spines['left'].set_color(cmap(i)[:3])
            axis.spines['right'].set_visible(False)
        else:
            axis.spines['left'].set_visible(False)
            axis.spines['right'].set_color(cmap(i)[:3])

    plt.tight_layout()
    plt.show()
If you call the above function with:
plot_df(df, 'time', [['price1', 'price2'], ['returns'], ['growth']])
then you will get:
NOTES
Since price1 and price2 share the same y axis, they must share the same color too, so I have to use a different linestyle to distinguish them.
The first element of the columns list (['price1', 'price2'] in this case) is always drawn on the left axis, the other elements on the right ones.
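The question mentions that each axis could simply carry a label such as price or returns. A minimal sketch of that idea follows. The helper label_axes is hypothetical and not part of the function above, the two axes are a stand-in for the ones plot_df builds internally, and the Agg backend is selected only so that the snippet runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def label_axes(axes, labels):
    """Hypothetical helper: give each y-axis a label in the color of its first line."""
    for axis, text in zip(axes, labels):
        color = axis.get_lines()[0].get_color()
        axis.set_ylabel(text, color=color)


# Minimal stand-in for the axes that plot_df builds internally.
fig, left = plt.subplots()
right = left.twinx()
left.plot([0, 1, 2], [1.0, 1.2, 1.1], color="tab:blue", label="price1")
left.plot([0, 1, 2], [0.9, 1.0, 1.3], color="tab:blue", linestyle="--", label="price2")
right.plot([0, 1, 2], [0.02, -0.01, 0.03], color="tab:orange", label="returns")

label_axes([left, right], ["price", "returns"])
```

To use it with the function above, plot_df would have to return its axes list or accept the labels as an extra parameter; either change is an assumption, not something the original code does.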
Upvotes: 1