Luis

Reputation: 3497

How to add error bars to interaction plot (statsmodels)?

I have the following code:

import numpy as np
import matplotlib.pyplot as plt
from statsmodels.graphics.factorplots import interaction_plot

a = np.array( [ item for item in [ 'a1', 'a2', 'a3' ] for _ in range(30) ] )
b = np.array( [ item for _ in range(45) for item in [ 'b1', 'b2' ] ] )
np.random.seed(123)
mse = np.ravel( np.column_stack( (np.random.normal(-1, 1, size=45 ), np.random.normal(2, 0.5, size=45 ) )) )
f = interaction_plot( a, b, mse )

Which gives:

[image: the resulting interaction plot]

Is there an easy way to add error bars to each point directly, for example with something like f.axes[0].errorbar()?

Or is it better to just make the plot directly with matplotlib?

Upvotes: 0

Views: 1947

Answers (2)

Rune Nisbeth

Reputation: 1

It may be a dubious approach, but my workaround was to call axes.errorbar() on the axes of the interaction plot. Doing so adds another line to the chart, which you then need to align with the lines that interaction_plot already drew.
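
One way to sidestep the alignment step is to let errorbar() draw both the means and the bars, which also covers the "plot it directly with matplotlib" option from the question. The following is only a minimal sketch under stated assumptions: errorbar_interaction_plot is a hypothetical helper name (it is not part of statsmodels), the bars show one sample standard deviation, and the x levels are simply placed at integer positions in sorted order.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def errorbar_interaction_plot(x, trace, response, ax=None):
    """Hypothetical helper: one line per trace level, cell means +/- 1 sample std."""
    if ax is None:
        _, ax = plt.subplots()
    data = pd.DataFrame({"x": x, "trace": trace, "response": response})
    # Mean and sample standard deviation (pandas uses ddof=1) of every (trace, x) cell.
    cells = (data.groupby(["trace", "x"])["response"]
                 .agg(["mean", "std"])
                 .reset_index())
    # Map the x levels to integer positions so that all traces share one axis.
    x_levels = sorted(cells["x"].unique())
    position = {level: i for i, level in enumerate(x_levels)}
    for name, group in cells.groupby("trace"):
        ax.errorbar(group["x"].map(position).to_numpy(),
                    group["mean"].to_numpy(),
                    yerr=group["std"].to_numpy(),
                    marker="o", capsize=3, label=str(name))
    ax.set_xticks(list(range(len(x_levels))))
    ax.set_xticklabels(x_levels)
    ax.set_xlabel("x")
    ax.set_ylabel("mean of response")
    ax.legend(title="trace")
    return ax, cells


# Same data as in the question.
a = np.array([item for item in ["a1", "a2", "a3"] for _ in range(30)])
b = np.array([item for _ in range(45) for item in ["b1", "b2"]])
np.random.seed(123)
mse = np.ravel(np.column_stack((np.random.normal(-1, 1, size=45),
                                np.random.normal(2, 0.5, size=45))))
ax, cells = errorbar_interaction_plot(a, b, mse)
```

Because each trace is drawn by a single errorbar() call, the bars sit on the plotted points by construction. The returned cells table holds the mean and standard deviation per cell in case you want to inspect the numbers behind the plot.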

Upvotes: 0

Luis

Reputation: 3497

Well, it seems that the feature is not yet directly supported, so I modified the source code of interaction_plot and created a new function. I am posting it here in case it is of use to somebody.


import numpy as np
import matplotlib.pyplot as plt
from pandas import DataFrame


def int_plot(x, trace, response, func=np.mean, ax=None, plottype='b',
             xlabel=None, ylabel=None, colors=None, markers=None,
             linestyles=None, legendloc='best', legendtitle=None,
# - - - My changes !!
             errorbars=False, errorbartyp='std',
# - - - -
             **kwargs):

    # Axes setup, assumed here so that the excerpt runs on its own.
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    data = DataFrame(dict(x=x, trace=trace, response=response))
    plot_data = data.groupby(['trace', 'x']).aggregate(func).reset_index()

# - - - My changes !!
    if errorbars:
        if errorbartyp == 'std':
            # sample standard deviation of each (trace, x) cell
            yerr = data.groupby(['trace', 'x']).aggregate(lambda xx: np.std(xx, ddof=1)).reset_index()
        elif errorbartyp == 'ci95':
            # half-width of the 95% confidence interval, see t_ci() below
            yerr = data.groupby(['trace', 'x']).aggregate(t_ci).reset_index()
        else:
            raise ValueError("Type of error bars %s not understood" % errorbartyp)
# - - - - - - -
    n_trace = len(plot_data['trace'].unique())

    # Style defaults, also assumed; None replaces the mutable [] defaults.
    colors = colors or ['C%d' % i for i in range(n_trace)]
    markers = markers or ['.'] * n_trace
    linestyles = linestyles or ['-'] * n_trace

    if plottype == 'both' or plottype == 'b':
        # group by a scalar key so that `values` is the trace label itself
        for i, (values, group) in enumerate(plot_data.groupby('trace')):
            # trace label
            label = str(group['trace'].values[0])
# - - - My changes !!
            if errorbars:
                ax.errorbar(group['x'], group['response'],
                            yerr=yerr.loc[yerr['trace'] == values]['response'].values,
                            color=colors[i], ecolor='black',
                            marker=markers[i], label='',
                            linestyle=linestyles[i], **kwargs)
# - - - - - - - - - -
            ax.plot(group['x'], group['response'], color=colors[i],
                    marker=markers[i], label=label,
                    linestyle=linestyles[i], **kwargs)

    ax.legend(loc=legendloc, title=legendtitle)
    return fig

With that, I could get this plot:

f = int_plot( a, b, mse, errorbars=True, errorbartyp='std' )

[image: the interaction plot with error bars]

Note: The code can also use the function t_ci() to compute the error bars as the half-width of a confidence interval for the mean. I defined the function like this:

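def t_ci( x, C=0.95 ):
    from scipy.stats import t

    x = np.array( x )
    n = len( x )
    # upper quantile of Student's t with n-1 degrees of freedom,
    # so that the returned half-width is positive
    tstat = t.ppf( (1+C)/2, n-1 )
    return np.std( x, ddof=1 ) * tstat / np.sqrt( n )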

Again, I just tweaked the function a bit to fit my current needs. The original function can be found here :)
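
For reference, the half-width of a two-sided t-based confidence interval for a mean is the upper t quantile at (1+C)/2 with n-1 degrees of freedom, multiplied by the standard error s/sqrt(n). The sketch below cross-checks that formula against scipy.stats.t.interval. The name ci_half_width and the random sample are illustrative assumptions, not part of the code above.

```python
import numpy as np
from scipy import stats


def ci_half_width(sample, confidence=0.95):
    """Half-width of the two-sided t confidence interval for the mean."""
    sample = np.asarray(sample, dtype=float)
    n = sample.size
    # Standard error of the mean, based on the sample standard deviation (ddof=1).
    sem = sample.std(ddof=1) / np.sqrt(n)
    # Upper quantile of Student's t with n-1 degrees of freedom (positive).
    return stats.t.ppf((1 + confidence) / 2, n - 1) * sem


# Illustrative sample, one cell's worth of observations.
rng = np.random.default_rng(123)
sample = rng.normal(2, 0.5, size=15)

half = ci_half_width(sample)
# Reference interval computed by scipy for the same confidence level.
low, high = stats.t.interval(0.95, sample.size - 1,
                             loc=sample.mean(), scale=stats.sem(sample))
print(np.isclose(high - sample.mean(), half))
print(np.isclose(sample.mean() - low, half))
```

A value like this can be passed straight to yerr, since matplotlib expects non-negative error bar lengths.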

Upvotes: 2
