veg2020

Reputation: 1020

Add horizontal lines from min point to max point per category in seaborn catplot

I am trying to visualize different types of "purchases" over quarterly periods for selected customers. To generate this visual I am using seaborn's catplot, but I can't work out how to add a horizontal line connecting the purchases of each fruit. Each line should start at the first dot for each fruit and end at the last dot for the same fruit. Any ideas on how to do this programmatically?

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

dta = pd.DataFrame(columns=["Date", "Fruit", "type"], data=[
    ['2017-01-01', 'Orange', 'FP'], ['2017-04-01', 'Orange', 'CP'],
    ['2017-07-01', 'Orange', 'CP'], ['2017-10-08', 'Orange', 'CP'],
    ['2017-01-01', 'Apple', 'NP'], ['2017-04-01', 'Apple', 'CP'],
    ['2017-07-01', 'Banana', 'NP'], ['2017-10-08', 'Orange', 'CP'],
])
dta['quarter'] = pd.PeriodIndex(dta.Date, freq='Q')

sns.catplot(x="quarter", y="Fruit", hue="type", kind="swarm", data=dta)

plt.show()

This is the result:

[Image: swarm plot produced by the code above]

How can I add individual horizontal lines that each connect the dots for purchases of orange and apple?

Upvotes: 2

Views: 1161

Answers (2)

tdy

Reputation: 41327

Each line should start at the first dot for each fruit and end at the last dot for the same fruit.

  1. Use groupby.ngroup to map the quarters to xtick positions
  2. Use groupby.agg to find each fruit's min and max xtick endpoints
  3. Use ax.hlines to plot horizontal lines from each fruit's min to max
import pandas as pd
import seaborn as sns

df = pd.DataFrame([['2017-01-01', 'Orange', 'FP'], ['2017-04-01', 'Orange', 'CP'],
                   ['2017-07-01', 'Orange', 'CP'], ['2017-10-08', 'Orange', 'CP'],
                   ['2017-01-01', 'Apple', 'NP'], ['2017-04-01', 'Apple', 'CP'],
                   ['2017-07-01', 'Banana', 'NP'], ['2017-10-08', 'Orange', 'CP']],
                  columns=['Date', 'Fruit', 'type'])
df['quarter'] = pd.PeriodIndex(df['Date'], freq='Q')

df = df.sort_values('quarter')                            # sort dataframe by quarter
df['xticks'] = df.groupby('quarter').ngroup()             # map quarter to xtick position
ends = df.groupby('Fruit')['xticks'].agg(['min', 'max'])  # find min and max xtick per fruit

g = sns.catplot(x='quarter', y='Fruit', hue='type', kind='swarm', s=8, data=df)
g.axes[0, 0].hlines(ends.index, ends['min'], ends['max']) # plot horizontal lines from each fruit's min to max
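Steps 1 to 3 involve no plotting, so they can be checked on the data alone. The sketch below wraps them in a hypothetical helper, `line_endpoints` (not part of the original answer), and confirms that shuffling the rows leaves the endpoints unchanged: `groupby` sorts its keys by default, so `ngroup` numbers the quarters chronologically whatever the row order. The sort in step 1 only matters for the x-axis order in catplot.

```python
import pandas as pd

# Same sample data as in the question
df = pd.DataFrame(
    [
        ["2017-01-01", "Orange", "FP"],
        ["2017-04-01", "Orange", "CP"],
        ["2017-07-01", "Orange", "CP"],
        ["2017-10-08", "Orange", "CP"],
        ["2017-01-01", "Apple", "NP"],
        ["2017-04-01", "Apple", "CP"],
        ["2017-07-01", "Banana", "NP"],
        ["2017-10-08", "Orange", "CP"],
    ],
    columns=["Date", "Fruit", "type"],
)
df["quarter"] = pd.PeriodIndex(df["Date"], freq="Q")


def line_endpoints(frame):
    """Hypothetical helper: min and max xtick position per fruit (steps 1-3)."""
    # Step 1: sort by quarter (needed for catplot's x order, not for ngroup)
    frame = frame.sort_values("quarter")
    # Step 2: ngroup numbers the quarter groups in sorted key order
    frame = frame.assign(xticks=frame.groupby("quarter").ngroup())
    # Step 3: first and last xtick position for each fruit
    return frame.groupby("Fruit")["xticks"].agg(["min", "max"])


ends = line_endpoints(df)

# Shuffling the rows beforehand does not change the endpoints
shuffled = df.sample(frac=1, random_state=0)
assert line_endpoints(shuffled).equals(ends)
print(ends)
```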

Detailed breakdown:

  1. catplot places the quarters on the x-axis in the order they first appear in the dataframe. The sample dataframe happens to be sorted by quarter already, but a real dataframe should be sorted explicitly:

    df = df.sort_values('quarter')
    
  2. Map the quarters to their xtick positions using groupby.ngroup:

    df['xticks'] = df.groupby('quarter').ngroup()
    
    #          Date   Fruit  type  quarter  xticks
    # 0  2017-01-01  Orange    FP   2017Q1       0
    # 1  2017-04-01  Orange    CP   2017Q2       1
    # 2  2017-07-01  Orange    CP   2017Q3       2
    # 3  2017-10-08  Orange    CP   2017Q4       3
    # 4  2017-01-01   Apple    NP   2017Q1       0
    # 5  2017-04-01   Apple    CP   2017Q2       1
    # 6  2017-07-01  Banana    NP   2017Q3       2
    # 7  2017-10-08  Orange    CP   2017Q4       3
    
  3. Find the min and max xticks to get the endpoints per Fruit using groupby.agg:

    ends = df.groupby('Fruit')['xticks'].agg(['min', 'max'])
    
    #         min  max
    # Fruit           
    # Apple     0    1
    # Banana    2    2
    # Orange    0    3
    
  4. Use ax.hlines to draw one horizontal line per Fruit, from its min endpoint to its max endpoint:

    g = sns.catplot(x='quarter', y='Fruit', hue='type', kind='swarm', s=8, data=df)
    ax = g.axes[0, 0]
    ax.hlines(ends.index, ends['min'], ends['max'])
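A variant of step 4, sketched under a few assumptions that go beyond the original answer: it keeps only fruits bought in more than one quarter (the question asks about Orange and Apple, and Banana's single purchase would otherwise give a zero-length line), and it passes `colors="gray"` and `zorder=0` to `hlines` as styling choices so the lines are drawn underneath the swarm markers.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.DataFrame(
    [
        ["2017-01-01", "Orange", "FP"],
        ["2017-04-01", "Orange", "CP"],
        ["2017-07-01", "Orange", "CP"],
        ["2017-10-08", "Orange", "CP"],
        ["2017-01-01", "Apple", "NP"],
        ["2017-04-01", "Apple", "CP"],
        ["2017-07-01", "Banana", "NP"],
        ["2017-10-08", "Orange", "CP"],
    ],
    columns=["Date", "Fruit", "type"],
)
df["quarter"] = pd.PeriodIndex(df["Date"], freq="Q")

df = df.sort_values("quarter")                            # quarters left to right
df["xticks"] = df.groupby("quarter").ngroup()             # quarter -> xtick position
ends = df.groupby("Fruit")["xticks"].agg(["min", "max"])  # endpoints per fruit

# Assumption: skip fruits seen in a single quarter (min == max, e.g. Banana)
ends = ends[ends["min"] < ends["max"]]

g = sns.catplot(x="quarter", y="Fruit", hue="type", kind="swarm", data=df)
ax = g.axes[0, 0]
# zorder=0 draws the lines before (i.e. underneath) the scatter markers
lines = ax.hlines(ends.index, ends["min"], ends["max"], colors="gray", zorder=0)
plt.show()
```

`hlines` returns a `LineCollection`, so the lines can be restyled afterwards (for example with `lines.set_linewidth(...)`) without redrawing the swarm.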
    

Upvotes: 2

kha

Reputation: 353

You just need to enable the horizontal grid for the chart as follows:

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

dta = pd.DataFrame(
    columns=["Date", "Fruit", "type"],
    data=[
        ["2017-01-01", "Orange", "FP"],
        ["2017-04-01", "Orange", "CP"],
        ["2017-07-01", "Orange", "CP"],
        ["2017-10-08", "Orange", "CP"],
        ["2017-01-01", "Apple", "NP"],
        ["2017-04-01", "Apple", "CP"],
        ["2017-07-01", "Banana", "NP"],
        ["2017-10-08", "Orange", "CP"],
    ],
)


dta["quarter"] = pd.PeriodIndex(dta.Date, freq="Q")
sns.catplot(x="quarter", y="Fruit", hue="type", kind="swarm", data=dta)
plt.grid(axis='y')
plt.show()
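The same grid can be switched on through the FacetGrid that catplot returns, rather than through pyplot's notion of the current Axes. Note that grid lines are drawn at each fruit's y tick across the full width of the Axes, so they do not start and stop at a fruit's first and last purchase. This is a sketch; `ax` is simply the name given here to the grid's single Axes.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

dta = pd.DataFrame(
    columns=["Date", "Fruit", "type"],
    data=[
        ["2017-01-01", "Orange", "FP"],
        ["2017-04-01", "Orange", "CP"],
        ["2017-07-01", "Orange", "CP"],
        ["2017-10-08", "Orange", "CP"],
        ["2017-01-01", "Apple", "NP"],
        ["2017-04-01", "Apple", "CP"],
        ["2017-07-01", "Banana", "NP"],
        ["2017-10-08", "Orange", "CP"],
    ],
)
dta["quarter"] = pd.PeriodIndex(dta.Date, freq="Q")

g = sns.catplot(x="quarter", y="Fruit", hue="type", kind="swarm", data=dta)
ax = g.axes[0, 0]  # the single Axes of this FacetGrid
ax.grid(axis="y")  # horizontal grid lines at every y tick
plt.show()
```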

[Image: preview of the plot with horizontal grid lines]

Upvotes: 1
