klho

How can I create a Seaborn line plot with 3 different y-axes?

I am trying to plot three sets of data, each with its own y-axis scale. I can plot the third line, but the y2 and y3 axes are drawn on top of each other.

I need to separate these two axes so that they are readable.

Can this be done with the Seaborn library?

This is the code:

import datetime
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Ingest the data
url = 'https://covid.ourworldindata.org/data/owid-covid-data.csv'
covid_data = pd.read_csv(url).set_index("location")

# Clean the data 
df = covid_data.copy()
df.date = pd.to_datetime(df.date)
df = df.loc[df['date'] > (datetime.datetime(2021, 4, 30)), :]
df = df[df.index.isin(['United States'])]

# Select the features of interest
new_cases = 'new_cases_smoothed_per_million'
patients = 'hosp_patients_per_million'
vaccinated = 'people_fully_vaccinated_per_hundred'
bedsT = 'hospital_beds_per_thousand'
bedsM = 'hospital_beds_per_million'
beds_used = 'hospital_beds_used'
df = df.loc[:, ['date', new_cases, patients, vaccinated, bedsT]]
df[bedsM] = df[bedsT] * 1000
df[beds_used] = df[patients] / df[bedsM]  # share of beds in use; vectorized instead of a row-wise apply


# Visualise the data
y1_color = "red"
y2_color = "green"
y3_color = "blue"

x1_axis = "date"
y1_axis = new_cases
y2_axis = vaccinated
y3_axis = beds_used

x1 = df[x1_axis]
y1 = df[y1_axis]
y2 = df[y2_axis]
y3 = df[y3_axis]
y2_limit = df[y2_axis].max()


fig, ax1 = plt.subplots(figsize=(16, 6))
ax1.set_title("United States")
ax2 = ax1.twinx()
ax3 = ax1.twinx()

ax2.set(ylim=(0, y2_limit))
sns.lineplot(data=df, x=x1, y=y1, ax=ax1, color=y1_color)  # plots the first set
sns.lineplot(data=df, x=x1, y=y2, ax=ax2, color=y2_color)  # plots the second set
sns.lineplot(data=df, x=x1, y=y3, ax=ax3, color=y3_color)  # plots the third set
plt.show()
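
Why the two right-hand axes collide: each twinx() call returns a new Axes whose y-axis sits on the right spine of the plot, and both twins keep the same default spine position, so ax2 and ax3 draw their ticks and labels in the same place. A minimal check with plain Matplotlib (no data needed; the Agg backend line is only there so the sketch runs without a display):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; not needed in a notebook
import matplotlib.pyplot as plt

fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax3 = ax1.twinx()

# Both twins put their y-axis label on the right ...
print(ax2.yaxis.get_label_position(), ax3.yaxis.get_label_position())

# ... and neither right spine has been moved from its default position.
pos2 = ax2.spines["right"].get_position()
pos3 = ax3.spines["right"].get_position()
print(pos2, pos3, pos2 == pos3)
```

Moving one of the two spines (or, with axisartist, creating a new fixed axis with an offset) is what separates them.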

Answers (1)

r-beginners

I modified your code based on the parasite-axes example (host_subplot with axisartist axes) in the official Matplotlib gallery.

import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1 import host_subplot
from mpl_toolkits import axisartist

# df, x1, y1, y2, y3, y2_limit and the *_color variables come from the question's code.

# Create the figure first so the size applies, then use a host axes instead of plt.subplots().
plt.figure(figsize=(16, 6))
host = host_subplot(111, axes_class=axisartist.Axes)
plt.subplots_adjust(right=0.75)  # leave room on the right for the offset third axis

host.set_title("United States")
ax2 = host.twinx()  # second y-axis, on the right edge
ax3 = host.twinx()  # third y-axis, moved outward below

# Give ax3 a new right axis 50 points outside the plot so it no longer overlaps ax2.
ax3.axis["right"] = ax3.new_fixed_axis(loc="right", offset=(50, 0))

# Show ticks, tick labels and axis labels on both right-hand axes.
ax2.axis["right"].toggle(all=True)
ax3.axis["right"].toggle(all=True)

ax2.set(ylim=(0, y2_limit))
sns.lineplot(data=df, x=x1, y=y1, ax=host, color=y1_color)  # first set, left axis
sns.lineplot(data=df, x=x1, y=y2, ax=ax2, color=y2_color)   # second set, first right axis
sns.lineplot(data=df, x=x1, y=y3, ax=ax3, color=y3_color)   # third set, offset right axis

plt.show()
