P Kuca
P Kuca

Reputation: 51

Plotly/Python creation of the nested x-axis bar-chart

I was wondering whether Plotly could add support for nested (multi-level) x-axes. For instance, when working with the standard “tips” dataset, a capability like the one implemented in fivecentplots, JMP, or Origin would be beneficial, e.g.:

import plotly.express as px
df = px.data.tips()
%load_ext autoreload
%autoreload 2
%matplotlib inline
import fivecentplots as fcp
import pandas as pd
import numpy as np
import os, sys, pdb
osjoin = os.path.join
db = pdb.set_trace
fcp.boxplot(df=df, y=‘tip’, groups=[‘time’, ‘sex’, ‘day’], legend=‘smoker’)

enter image description here

would generate: nested X-axis bar-chart

If this capability already exists, please add a comment.

Upvotes: 5

Views: 929

Answers (1)

amance
amance

Reputation: 1770

It's doable, but it takes a lot more steps than the method you're currently using. One way to achieve this with Plotly is to create two subplots, each with a multicategory x-axis (one for Lunch and one for Dinner), with zero space in between so that they look like a single plot.

import numpy as np
import pandas as pd
import plotly.express as px
# go and make_subplots are used below; without these imports the script
# fails with NameError.
import plotly.graph_objects as go
from pandas.api.types import CategoricalDtype
from plotly.subplots import make_subplots

df = px.data.tips()

#set order for days
days = CategoricalDtype(
    ['Thur', 'Fri', 'Sat', 'Sun'],
    ordered=True
    )

df['day'] = df['day'].astype(days)

#sort df
df.sort_values(['time', 'sex', 'day'], inplace=True)

#row masks reused by every trace
is_dinner = df['time']=='Dinner'
is_lunch = df['time']=='Lunch'
is_smoker = df['smoker']=='Yes'
is_non_smoker = df['smoker']=='No'


def nested_x(mask):
    """Return [sex values, day values] for the rows selected by mask.

    Passing a list of two lists as x gives the trace a two-level
    (multicategory) x-axis: sex as the outer level, day as the inner level.
    """
    return [df['sex'][mask].tolist(), df['day'][mask].tolist()]


def mean_tips(mask):
    """Return the mean tip per (sex, day) pair for the rows selected by mask."""
    # observed=True skips day categories that do not occur in the selection;
    # dropna() is kept so the result is the same on any pandas version.
    return (df[['sex', 'day', 'tip']][mask]
            .groupby(['sex', 'day'], observed=True)
            .mean()
            .reset_index()
            .dropna())


#create framework: two side-by-side panels sharing the y-axis, no gap
fig = make_subplots(rows=1,
            cols=2,
            shared_yaxes=True,
            horizontal_spacing=0,
            column_widths=[7/11, 4/11])

#create "Dinner" boxplots
fig.add_trace(go.Box(x=nested_x(is_dinner),
             y=df['tip'][is_dinner],
             boxpoints=False,
             pointpos=0,
             line=dict(color='gray',
                   width=1),
             fillcolor='white',
             showlegend=False),
          row=1,
          col=1)
#add "Dinner" smokers
fig.add_trace(go.Scatter(x=nested_x(is_dinner & is_smoker),
             y=df['tip'][is_dinner & is_smoker],
             mode='markers',
             marker=dict(color='red',
                     symbol='circle-open',
                     size=10),
             name='Yes'
             ),
          row=1,
          col=1)

#add "Dinner" non-smokers
fig.add_trace(go.Scatter(x=nested_x(is_dinner & is_non_smoker),
             y=df['tip'][is_dinner & is_non_smoker],
             mode='markers',
             marker=dict(color='green',
                     symbol='cross-thin-open',
                     size=10),
             name='No'
             ),
          row=1,
          col=1)

df_mean = mean_tips(is_dinner)

#add "Dinner" mean line
fig.add_trace(go.Scatter(x=[df_mean['sex'].tolist(), df_mean['day'].tolist()],
             y=df_mean['tip'].tolist(),
             showlegend=False,
             marker=dict(color='black')
             ),
          row=1,
          col=1)

#create "Lunch" boxplots
fig.add_trace(go.Box(x=nested_x(is_lunch),
             y=df['tip'][is_lunch],
             boxpoints=False,
             pointpos=0,
             line=dict(color='gray',
                   width=1),
             fillcolor='white',
             showlegend=False),
          row=1,
          col=2)
#add "Lunch" smokers (legend entry already provided by the "Dinner" trace)
fig.add_trace(go.Scatter(x=nested_x(is_lunch & is_smoker),
             y=df['tip'][is_lunch & is_smoker],
             mode='markers',
             marker=dict(color='red',
                     symbol='circle-open',
                     size=10),
             showlegend=False
             ),
          row=1,
          col=2)
#add "Lunch" non-smokers (legend entry already provided by the "Dinner" trace)
fig.add_trace(go.Scatter(x=nested_x(is_lunch & is_non_smoker),
             y=df['tip'][is_lunch & is_non_smoker],
             mode='markers',
             marker=dict(color='green',
                     symbol='cross-thin-open',
                     size=10),
             showlegend=False
             ),
          row=1,
          col=2)

df_mean = mean_tips(is_lunch)

#add "Lunch" mean line
fig.add_trace(go.Scatter(x=[df_mean['sex'].tolist(), df_mean['day'].tolist()],
             y=df_mean['tip'].tolist(),
             showlegend=False,
             marker=dict(color='black')
             ),
          row=1,
          col=2)

fig.update_xaxes(title='Dinner', col=1)
fig.update_xaxes(title='Lunch', col=2)
fig.update_yaxes(title='tip', col=1)
fig.update_layout(legend_title='Smoker')
fig.show()

Figure

Performance Update:

Using loops to make it faster plus other minor changes

import pandas as pd
# px and CategoricalDtype are used below; without these imports the script
# fails with NameError when run on its own.
import plotly.express as px
import plotly.graph_objects as go
from pandas.api.types import CategoricalDtype
from plotly.subplots import make_subplots

df = px.data.tips()

#set order for days
days = CategoricalDtype(['Thur', 'Fri', 'Sat', 'Sun'],
                        ordered=True)

df['day'] = df['day'].astype(days)

#sort df
df = df.sort_values(['time', 'sex', 'day']).reset_index(drop=True)

#create framework: two side-by-side panels sharing the y-axis, no gap
fig = make_subplots(rows=1,
                    cols=2,
                    shared_yaxes=True,
                    horizontal_spacing=0,
                    column_widths=[4/11, 7/11])

#subplot column for each meal time
col_order = {'Lunch':1, 'Dinner':2}
#marker [color, symbol] for each smoker value (same for both panels)
smoker_dict = {'Yes':['red', 'circle-open'], 'No':['green', 'cross-thin-open']}

for t, col in col_order.items():
    df_time = df[df['time']==t]
    #add mean line
    df_mean = df_time.groupby(['sex', 'day'], observed=True)['tip'].mean().reset_index()
    fig.add_trace(go.Scatter(x=[df_mean['sex'].tolist(), df_mean['day'].tolist()],
                             y=df_mean['tip'].tolist(),
                             showlegend=False,
                             marker=dict(color='black')
                             ),
                  row=1,
                  col=col)
    #add boxplots; x is [outer, inner] category lists (sex, then day)
    fig.add_trace(go.Box(x=[df_time['sex'].tolist(), df_time['day'].tolist()],
                         y=df_time['tip'],
                         boxpoints=False,
                         pointpos=0,
                         line=dict(color='gray',
                                   width=1),
                         fillcolor='white',
                         showlegend=False),
                  row=1,
                  col=col)
    #add smokers and non-smokers
    sl = col == 1 #avoid duplicate legend entries
    for s, (color, symbol) in smoker_dict.items():
        df_points = df_time[df_time['smoker']==s]
        #box traces with transparent fill and outline, so only the points show
        fig.add_trace(go.Box(x=[df_points['sex'].tolist(), df_points['day'].tolist()],
                                 y=df_points['tip'],
                                 marker=dict(color=color,
                                             symbol=symbol,
                                             size=10),
                                 fillcolor='rgba(255,255,255,0)',
                                 line_color='rgba(255,255,255,0)',
                                 boxpoints='all',
                                 pointpos=0,
                                 name=s,
                                 showlegend=sl,
                                 ),
                      row=1,
                      col=col)
    fig.update_xaxes(title=t, col=col)
fig.update_yaxes(title='tip', col=1)
fig.update_xaxes(showline=True,
                 linecolor='black',
                 linewidth=1,
                 mirror=True)
fig.update_yaxes(showline=True,
                 linecolor='black',
                 linewidth=1,
                 mirror=True)
fig.update_yaxes(mirror=False, col=1) #center line was too thick
fig.update_traces(selector=dict(type='box'), jitter=1) #optional jitter
fig.update_layout(legend_title='Smoker',
                  plot_bgcolor='whitesmoke',
                  height=800, width=800)
fig.show()

enter image description here

Upvotes: 5

Related Questions