Reputation: 1020
I am trying to visualize different types of "purchases" over a quarterly period for selected customers. To generate this visual, I am using seaborn's catplot function, but I am unable to add a horizontal line that connects the dots for each purchased fruit. Each line should start at the first dot for each fruit and end at the last dot for the same fruit. Any ideas on how to do this programmatically?
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
dta = pd.DataFrame(columns=["Date", "Fruit", "type"], data=[
    ['2017-01-01', 'Orange', 'FP'],
    ['2017-04-01', 'Orange', 'CP'],
    ['2017-07-01', 'Orange', 'CP'],
    ['2017-10-08', 'Orange', 'CP'],
    ['2017-01-01', 'Apple', 'NP'],
    ['2017-04-01', 'Apple', 'CP'],
    ['2017-07-01', 'Banana', 'NP'],
    ['2017-10-08', 'Orange', 'CP'],
])
dta['quarter'] = pd.PeriodIndex(dta.Date, freq='Q')
sns.catplot(x="quarter", y="Fruit", hue="type", kind="swarm", data=dta)
plt.show()
This is the result:
How can I add individual horizontal lines that each connect the dots for purchases of orange and apple?
Upvotes: 2
Views: 1161
Reputation: 41327
Each line should start at the first dot for each fruit and end at the last dot for the same fruit.
groupby.ngroup to map the quarters to xtick positions
groupby.agg to find each fruit's min and max xtick endpoints
ax.hlines to plot horizontal lines from each fruit's min to max

df = pd.DataFrame([['2017-01-01', 'Orange', 'FP'], ['2017-04-01', 'Orange', 'CP'], ['2017-07-01', 'Orange', 'CP'], ['2017-10-08', 'Orange', 'CP'], ['2017-01-01', 'Apple', 'NP'], ['2017-04-01', 'Apple', 'CP'], ['2017-07-01', 'Banana', 'NP'], ['2017-10-08', 'Orange', 'CP']], columns=['Date', 'Fruit', 'type'])
df['quarter'] = pd.PeriodIndex(df['Date'], freq='Q')
df = df.sort_values('quarter') # sort dataframe by quarter
df['xticks'] = df.groupby('quarter').ngroup() # map quarter to xtick position
ends = df.groupby('Fruit')['xticks'].agg(['min', 'max']) # find min and max xtick per fruit
g = sns.catplot(x='quarter', y='Fruit', hue='type', kind='swarm', s=8, data=df)
g.axes[0, 0].hlines(ends.index, ends['min'], ends['max']) # plot horizontal lines from each fruit's min to max
catplot plots the xticks in the order they appear in the dataframe. The sample dataframe is already sorted by quarter, but the real dataframe should be sorted explicitly:
df = df.sort_values('quarter')
Map the quarters to their xtick positions using groupby.ngroup:
df['xticks'] = df.groupby('quarter').ngroup()
#          Date   Fruit type quarter  xticks
# 0  2017-01-01  Orange   FP  2017Q1       0
# 1  2017-04-01  Orange   CP  2017Q2       1
# 2  2017-07-01  Orange   CP  2017Q3       2
# 3  2017-10-08  Orange   CP  2017Q4       3
# 4  2017-01-01   Apple   NP  2017Q1       0
# 5  2017-04-01   Apple   CP  2017Q2       1
# 6  2017-07-01  Banana   NP  2017Q3       2
# 7  2017-10-08  Orange   CP  2017Q4       3
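To illustrate why the explicit sort matters, here is a small sketch with made-up rows that are deliberately out of chronological order (the `demo` frame and `positions` dict are illustrative names, not part of the question's data). With pandas' default `groupby(sort=True)`, `ngroup` numbers the groups by sorted key rather than by first appearance, so the positions only match the plotted xticks once the dataframe itself is sorted by quarter.

```python
import pandas as pd

# Illustrative rows, intentionally not in chronological order.
demo = pd.DataFrame({
    "Date": ["2017-10-08", "2017-01-01", "2017-07-01", "2017-04-01"],
    "Fruit": ["Orange", "Apple", "Banana", "Orange"],
})
demo["quarter"] = pd.PeriodIndex(demo["Date"], freq="Q")

# groupby sorts group keys by default, so ngroup follows quarter order,
# not the order in which the rows appear.
demo["xticks"] = demo.groupby("quarter").ngroup()

# Map each quarter label to its assigned position for inspection.
positions = dict(zip(demo["quarter"].astype(str), demo["xticks"]))
print(positions)
```

Here the first row (2017Q4) gets position 3 even though it appears first, which is why the dataframe should be sorted before plotting.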
Find the min and max xticks to get the endpoints per Fruit using groupby.agg:
ends = df.groupby('Fruit')['xticks'].agg(['min', 'max'])
#         min  max
# Fruit
# Apple     0    1
# Banana    2    2
# Orange    0    3
Use ax.hlines to plot a horizontal line per Fruit from min-endpoint to max-endpoint:
g = sns.catplot(x='quarter', y='Fruit', hue='type', kind='swarm', s=8, data=df)
ax = g.axes[0, 0]
ax.hlines(ends.index, ends['min'], ends['max'])
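If passing the fruit names as y values does not line up with the rows in a given seaborn/matplotlib combination, one possible fallback is to compute numeric y positions as well. The sketch below is an assumption-laden variant, not the method above: `hline_segments` is a hypothetical helper, and it assumes the categories on the y axis are drawn in order of first appearance in the already-sorted dataframe passed to catplot.

```python
import pandas as pd

def hline_segments(df, x="quarter", y="Fruit"):
    """Return min/max x positions and a numeric y position per category.

    Expects `df` to be the same, already sorted frame that is plotted.
    """
    xpos = df.groupby(x).ngroup()                   # quarter -> xtick position
    ends = xpos.groupby(df[y]).agg(["min", "max"])  # endpoints per category
    # Assumption: categories are drawn in order of first appearance.
    order = {name: i for i, name in enumerate(pd.unique(df[y]))}
    ends["y"] = ends.index.map(order)
    return ends

df = pd.DataFrame(
    [['2017-01-01', 'Orange', 'FP'], ['2017-04-01', 'Orange', 'CP'],
     ['2017-07-01', 'Orange', 'CP'], ['2017-10-08', 'Orange', 'CP'],
     ['2017-01-01', 'Apple', 'NP'], ['2017-04-01', 'Apple', 'CP'],
     ['2017-07-01', 'Banana', 'NP'], ['2017-10-08', 'Orange', 'CP']],
    columns=['Date', 'Fruit', 'type'])
df['quarter'] = pd.PeriodIndex(df['Date'], freq='Q')
df = df.sort_values('quarter', kind='stable')  # stable sort keeps tie order

segments = hline_segments(df)
print(segments)
```

The result could then be drawn with something like `ax.hlines(segments['y'], segments['min'], segments['max'])`, using the same `ax` as above.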
Upvotes: 2
Reputation: 353
You just need to enable the horizontal grid for the chart as follows:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
dta = pd.DataFrame(
columns=["Date", "Fruit", "type"],
data=[
["2017-01-01", "Orange", "FP"],
["2017-04-01", "Orange", "CP"],
["2017-07-01", "Orange", "CP"],
["2017-10-08", "Orange", "CP"],
["2017-01-01", "Apple", "NP"],
["2017-04-01", "Apple", "CP"],
["2017-07-01", "Banana", "NP"],
["2017-10-08", "Orange", "CP"],
],
)
dta["quarter"] = pd.PeriodIndex(dta.Date, freq="Q")
sns.catplot(x="quarter", y="Fruit", hue="type", kind="swarm", data=dta)
plt.grid(axis='y')
plt.show()
Upvotes: 1