Reputation: 61967
I am attempting to recreate the following plot from the book *Introduction to Statistical Learning* using seaborn.
I specifically want to recreate this using seaborn's `lmplot` to create the first two plots and `boxplot` to create the third. The main problem is that `lmplot` creates a `FacetGrid` (according to this answer), which forces me to hackily add another matplotlib Axes for the boxplot. I was wondering if there was an easier way to achieve this. Below, I have to do quite a bit of manual manipulation to get the desired plot.
# Draw one regression facet per value of "variable" (age, year, ...), with a
# separate fit per education level. Data variables are passed by keyword:
# positional x/y are rejected by seaborn >= 0.12.
seaborn_grid = sns.lmplot(
    data=df_melt,
    x="value",
    y="wage",
    col="variable",
    hue="education",
    sharex=False,
)
fig = seaborn_grid.fig
fig.set_figwidth(8)

# Place an extra Axes to the right of the second facet, using the same
# horizontal step as between the first two facets. get_position() is the
# public accessor for the Axes box (``_position`` is private).
left, bottom, width, height = fig.axes[0].get_position().bounds
left2 = fig.axes[1].get_position().x0
box_ax = fig.add_axes(
    (left2 + (left2 - left), bottom, width, height),
    # Share the facets' y scale so the boxplot is read against the same wage
    # axis; otherwise hiding its y labels would conceal a different scale.
    sharey=fig.axes[0],
)
sns.boxplot(data=df_wage, x="education", y="wage", ax=box_ax)

# Hide the boxplot's y tick labels via tick_params: with a shared y axis,
# set_yticklabels([]) would change the shared formatter and blank the
# facets' labels as well.
box_ax.tick_params(axis="y", labelleft=False)
box_ax.tick_params(axis="x", labelrotation=30)
box_ax.set(xlabel="", ylabel="")

# Push the figure legend to the right so it does not overlap the added Axes.
leg = fig.legends[0]
leg.set_bbox_to_anchor([0, 0.1, 1.5, 1])
Sample data for DataFrames:
import pandas as pd

# Long-form sample: "variable" names the quantity on the x axis of each facet
# and "value" holds its value. These must be DataFrames (not plain dicts):
# the plotting code selects and groups columns by name.
df_melt = pd.DataFrame(
    {
        'education': ['1. < HS Grad', '4. College Grad', '3. Some College', '4. College Grad', '2. HS Grad'],
        'value': [18, 24, 45, 43, 50],
        'variable': ['age', 'age', 'age', 'age', 'age'],
        'wage': [75.0431540173515, 70.47601964694451, 130.982177377461, 154.68529299563, 75.0431540173515],
    }
)

# One row per observation: education level and wage, for the boxplot.
df_wage = pd.DataFrame(
    {
        'education': ['1. < HS Grad', '4. College Grad', '3. Some College', '4. College Grad', '2. HS Grad'],
        'wage': [75.0431540173515, 70.47601964694451, 130.982177377461, 154.68529299563, 75.0431540173515],
    }
)
Upvotes: 83
Views: 251925
Reputation: 2951
Plotting everything (two `sns.regplot`s and an `sns.boxplot`) in one Matplotlib figure (i.e., `fig, axs = plt.subplots(...)`)
Building off of the suggestion in the accepted answer to use two `sns.regplot`s instead of `sns.lmplot`, here is a fully fleshed-out example that closely mirrors the reference figure provided in your question.
The figure above was produced from the following code:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

fig, axs = plt.subplots(ncols=3, sharey=True, figsize=(18, 6), dpi=300)

# Plots 1 & 2: quadratic fit of "Wage" on "Age", linear fit of "Wage" on "Year"
for ax, variate, order in zip(axs[:2], ["Age", "Year"], [2, 1]):
    # Raw observations as semi-transparent gray points, kept below the fit
    # line via zorder=1.
    sns.stripplot(
        x=variate,
        y="Wage",
        data=df,
        ax=ax,
        native_scale=True,
        color="gray",
        zorder=1,
        alpha=0.5,
        legend=False,
    )
    # Only the fitted curve (scatter=False): the points were drawn above.
    sns.regplot(
        x=variate,
        y="Wage",
        data=df,
        ax=ax,
        scatter=False,
        truncate=False,
        order=order,
        color="deepskyblue",
    )
    ax.set_xlabel(variate, labelpad=25, fontsize=18)

# Plot 3: Boxplot of "Wage" by "Education"; hue repeats x so each box gets its
# own color and a legend keyed by education level can be shown.
sns.boxplot(
    x="Education",
    y="Wage",
    data=df,
    hue="Education",
    ax=axs[2],
    legend=True,
)

# Adjust axes labels for better readability
axs[0].set_ylabel("Wage", labelpad=25, fontsize=18)
axs[2].set_xlabel("Education", labelpad=25, fontsize=18)
for ax in axs:
    ax.tick_params(
        axis="both", which="major", labelsize=12, length=5, width=1.0
    )

# Shorten the boxplot tick labels to the level number ("1", "2", ...); the
# legend carries the full names. Reading the levels from the column's own
# categories keeps the labels in the column's category order instead of
# relying on a separately defined list.
levels = df["Education"].cat.categories
axs[2].set_xticks(np.arange(len(levels)))
axs[2].set_xticklabels([label.split()[0][0] for label in levels])
axs[2].legend(
    loc="center left", bbox_to_anchor=(1, 0.5), title="Education Level"
)

# tight_layout computes its own subplot spacing; a preceding
# subplots_adjust(wspace=...) would simply be overwritten, so none is used.
plt.tight_layout()
plt.show()
using data simulated via:
import pandas as pd
import numpy as np

# Simulate example data with correlation and meaningful education levels
np.random.seed(0)

# Set parameter values
n_samples = 1000
mean_wage = 120
std_dev_wage = 60
age_min = 16
age_max = 80
peak_age = 50

age_data = np.random.uniform(age_min, age_max, n_samples)

# Generate wage data mimicking the rise-then-fall trend of the reference
# figure. Both the mean and the spread are multiplied by ``scale``, which
# grows linearly up to 1 at ``peak_age`` and then shrinks linearly toward 0
# at age 100; ``base`` adds a further linear drift of 0.9 per year of age.
scale = np.where(
    age_data <= peak_age,
    age_data / peak_age,
    (100 - age_data) / (100 - peak_age),
)
base = (age_data / 100) * 90
# One normal draw per person in array order (the same random stream an
# element-by-element loop would consume), folded to non-negative wages.
wage_data = np.abs(
    np.random.normal(mean_wage * scale + base, std_dev_wage * scale)
)

education_levels = [
    "1. < HS Grad",
    "2. HS Grad",
    "3. Some College",
    "4. College Grad",
    "5. Postgraduate",
]
# Assign education levels vs. age by weighted probabilities
def assign_education(age):
    # Draw one education level at random for a person of the given age.
    #
    # The probabilities depend on the age bracket: 60 and over, 45-59,
    # 25-44, or under 25. The draw uses NumPy's global random state
    # (np.random.choice), so results follow np.random.seed.
    levels = [
        "1. < HS Grad",
        "2. HS Grad",
        "3. Some College",
        "4. College Grad",
        "5. Postgraduate",
    ]
    # (lower age bound, weights) from the oldest bracket down; the first
    # bound that ``age`` reaches selects the weights.
    brackets = (
        (60, [0.05, 0.35, 0.25, 0.3, 0.05]),
        (45, [0.05, 0.25, 0.25, 0.35, 0.1]),
        (25, [0.1, 0.1, 0.3, 0.3, 0.2]),
    )
    probabilities = [0.2, 0.39, 0.3, 0.1, 0.01]  # under 25 (fallback)
    for lower_bound, bracket_weights in brackets:
        if age >= lower_bound:
            probabilities = bracket_weights
            break
    return np.random.choice(levels, p=probabilities)
education_data = np.array([assign_education(age) for age in age_data])

# Survey year for each row, drawn uniformly from 2016-2023. It must be
# defined before the DataFrame below references it.
year_data = np.random.randint(2016, 2024, n_samples)

# Build "Education" directly as an ordered categorical. Declaring the
# categories up front keeps every level in a fixed order even if a small
# sample never draws one of them; inferring the categories with
# astype("category") and then calling cat.reorder_categories would raise
# in that case.
df = pd.DataFrame(
    {
        "Education": pd.Categorical(
            education_data, categories=sorted(education_levels), ordered=True
        ),
        "Age": age_data,
        "Year": year_data,
        "Wage": wage_data,
    }
)

print(f"DataFrame:\n{'-'*50}\n{df}\n")
print(f"DataFrame column datatypes:\n{'-'*50}\n{df.dtypes}\n")
print(
    f"DataFrame 'Education' category order:\n{'-'*50}\n{df.Education.values}"
)
DataFrame:
--------------------------------------------------
Education Age Year Wage
0 4. College Grad 51.124064 2016 157.349244
1 4. College Grad 61.772119 2022 148.226233
2 4. College Grad 54.576856 2023 258.951815
3 3. Some College 50.872524 2019 151.065454
4 2. HS Grad 43.113907 2022 116.458425
.. ... ... ... ...
995 2. HS Grad 22.251288 2022 73.171386
996 2. HS Grad 48.955021 2022 55.975291
997 4. College Grad 76.058369 2016 102.863747
998 3. Some College 30.633379 2022 108.192692
999 4. College Grad 59.337033 2018 214.298984
[1000 rows x 4 columns]
DataFrame column datatypes:
--------------------------------------------------
Education category
Age float64
Year int64
Wage float64
dtype: object
DataFrame 'Education' category order:
--------------------------------------------------
['4. College Grad', '5. Postgraduate', '4. College Grad', '3. Some College', '2. HS Grad', ..., '2. HS Grad', '3. Some College', '5. Postgraduate', '3. Some College', '4. College Grad']
Length: 1000
Categories (5, object): ['1. < HS Grad' < '2. HS Grad' < '3. Some College' < '4. College Grad' < '5. Postgraduate']
Upvotes: 2
Reputation: 23111
As of seaborn 0.13.0 (over 7 years after this question was posted), it's still really difficult to add subplots to seaborn figure-level objects without messing with the underlying figure positions. In fact, the method shown in the OP is probably the most readable way to do it.
With that being said, as suggested by Diziet Asahi, you may want to forgo seaborn FacetGrids (e.g. `lmplot`, `catplot`, etc.) altogether. In that case, use seaborn Axes-level methods to create an equivalent figure (e.g. `regplot` instead of `lmplot`, `scatterplot` + `lineplot` instead of `relplot`, etc.) and add more subplots, such as a `boxplot`, to the figure. To do that, group your data by the column you were going to pass as the `col` kwarg of `lmplot` (and `groupby` each sub-dataframe by the column you were going to pass as the `hue` kwarg), then draw the plots using data from the sub-dataframes.
As an example, using the data in the OP, we could do the following, which creates a figure somewhat equivalent to the `lmplot` one but adds a boxplot on the right:
# Split the long-form data the same way lmplot's `col='variable'` would:
# one sub-dataframe (and one subplot) per distinct value of 'variable'.
groupby_object = df_melt.groupby('variable')
number_of_columns = groupby_object.ngroups

# One Axes per group, plus an extra Axes on the right for the boxplot.
fig, axs = plt.subplots(1, number_of_columns + 1, sharey=True)

# Pair each group with its Axes; zip stops after the last group, leaving the
# final Axes free.
for ax, (_, g) in zip(axs, groupby_object):
    sns.regplot(data=g, x='value', y='wage', ax=ax)

sns.boxplot(data=df_wage, x='education', y='wage', hue='education', ax=axs[-1])
The example in the OP uses the `hue=` kwarg to draw separate lines of fit by `'education'`. To do that, we could group each sub-dataframe by the `'education'` column again and plot multiple regplots, one per education level, on the same Axes. A working example is as follows:
# One subplot per distinct 'variable' (like lmplot's col=), plus one for the boxplot.
groupby_object = df_melt.groupby('variable')
number_of_columns = groupby_object.ngroups
fig, axs = plt.subplots(1, number_of_columns + 1, figsize=(12, 5), sharey=True)

# Give every education level one fixed color shared by all panels. Without an
# explicit color, each regplot call takes the next color in its Axes' color
# cycle, so a level's color depends on which levels happen to be present in
# that facet and need not match the boxplot's hue colors.
levels = sorted(set(df_melt['education']) | set(df_wage['education']))
palette = dict(zip(levels, sns.color_palette(n_colors=len(levels))))

labelled = set()
for ax, (variable, g) in zip(axs, groupby_object):
    for label, g1 in g.groupby('education'):
        sns.regplot(
            data=g1,
            x='value',
            y='wage',
            color=palette[label],
            # Label each level exactly once so fig.legend gets one entry per
            # level, including levels absent from the first facet.
            label=None if label in labelled else label,
            scatter_kws={'alpha': 0.7},
            ax=ax,
        )
        labelled.add(label)
    # Title from the group key ('age' -> 'Age') instead of a hard-coded list.
    ax.set_title(str(variable).capitalize())

sns.boxplot(
    data=df_wage,
    x='education',
    y='wage',
    hue='education',
    order=levels,
    hue_order=levels,
    palette=palette,
    ax=axs[-1],
)
axs[-1].set(ylabel='', xlabel='', title='Education')
axs[-1].tick_params(axis='x', labelrotation=30)
_ = fig.legend(bbox_to_anchor=(0.92, 0.5), loc="center left")
Using the following sample dataset (I had to create a new dataset since OP's sample is not rich enough to make a proper graph):
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n_rows = 100

# '4. College Grad' appears twice in this pool, so it is drawn about twice
# as often as each of the other listed levels.
levels_pool = ['1. < HS Grad', '4. College Grad', '3. Some College', '4. College Grad', '2. HS Grad']

# The draw order fixes the generated values: education, wage, value, variable.
edu = rng.choice(levels_pool, size=n_rows)
wage = rng.normal(75, 25, n_rows)
values = rng.normal(30, 20, n_rows)
variables = rng.choice(['age', 'year'], n_rows)

df_melt = pd.DataFrame({'education': edu, 'value': values, 'variable': variables, 'wage': wage})
df_wage = pd.DataFrame({'education': edu, 'wage': wage})
the above code plots the following figure:
Upvotes: 2
Reputation: 40697
One possibility would be to NOT use `lmplot()`, but to use `regplot()` directly instead. `regplot()` plots on the Axes you pass as an argument with `ax=`.
You lose the ability to automatically split your dataset according to a certain variable, but if you know beforehand the plots you want to generate, it shouldn't be a problem.
Something like this:
import matplotlib.pyplot as plt
import seaborn as sns

fig, axs = plt.subplots(ncols=3)

# df_melt is long-form, so each regression panel must receive only the rows
# of one "variable". Passing the whole frame to both regplots would draw the
# same pooled fit twice.
for ax, (variable, subset) in zip(axs[:2], df_melt.groupby('variable')):
    sns.regplot(x='value', y='wage', data=subset, ax=ax)
    ax.set_xlabel(variable)

# Use df_wage (one row per observation) for the boxplot. df_melt presumably
# repeats each wage once per melted variable, which would double-count
# observations here.
sns.boxplot(x='education', y='wage', data=df_wage, ax=axs[2])
Upvotes: 174