I have the following code:
```python
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.graphics.factorplots import interaction_plot

a = np.array([item for item in ['a1', 'a2', 'a3'] for _ in range(30)])
b = np.array([item for _ in range(45) for item in ['b1', 'b2']])
np.random.seed(123)
mse = np.ravel(np.column_stack((np.random.normal(-1, 1, size=45),
                                np.random.normal(2, 0.5, size=45))))
f = interaction_plot(a, b, mse)
```
Which gives an interaction plot with one line per level of `b` (image not included).
Is there an easy way to add error bars to each point directly, e.g. with something like `f.axes[0].errorbar()`? Or is it better to just make the plot directly with matplotlib?
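For comparison, making the plot directly with matplotlib could look roughly like this. It is a minimal sketch, assuming the error bars should show the sample standard deviation of each (`b`, `a`) cell; the names `df`, `stats` and `x_pos` are just for illustration and not part of any library.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Same synthetic data as in the question
a = np.array([item for item in ['a1', 'a2', 'a3'] for _ in range(30)])
b = np.array([item for _ in range(45) for item in ['b1', 'b2']])
np.random.seed(123)
mse = np.ravel(np.column_stack((np.random.normal(-1, 1, size=45),
                                np.random.normal(2, 0.5, size=45))))

df = pd.DataFrame({'a': a, 'b': b, 'mse': mse})
# One row per (b, a) cell with the cell mean and sample standard deviation (ddof=1)
stats = df.groupby(['b', 'a'])['mse'].agg(['mean', 'std']).reset_index()

# Categorical x axis: place the sorted levels of a at 0, 1, 2, ...
x_levels = sorted(df['a'].unique())
x_pos = {lvl: i for i, lvl in enumerate(x_levels)}

fig, ax = plt.subplots()
for trace, group in stats.groupby('b'):
    # One line with error bars per level of b
    ax.errorbar([x_pos[v] for v in group['a']], group['mean'].to_numpy(),
                yerr=group['std'].to_numpy(), marker='o', capsize=4, label=trace)
ax.set_xticks(range(len(x_levels)))
ax.set_xticklabels(x_levels)
ax.set_xlabel('a')
ax.set_ylabel('mean of mse')
ax.legend(title='b')
# plt.show()  # uncomment to display the figure interactively
```

Doing it by hand costs a few more lines, but the bars and the lines come from the same `errorbar` call, so nothing has to be aligned afterwards.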
It may be a dubious way to get error bars, but my workaround was to call `axes.errorbar()` myself. Doing so adds another line to the chart, which you then need to align with the lines of the interaction plot.
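A minimal sketch of that workaround, assuming the single Axes of the figure returned by `interaction_plot` holds exactly one `Line2D` per trace level, drawn in sorted order of the levels. Reading the x/y positions back from those lines avoids re-deriving where `interaction_plot` placed the categorical levels, and `fmt='none'` draws only the bars, so no extra connecting line appears.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.graphics.factorplots import interaction_plot

a = np.array([item for item in ['a1', 'a2', 'a3'] for _ in range(30)])
b = np.array([item for _ in range(45) for item in ['b1', 'b2']])
np.random.seed(123)
mse = np.ravel(np.column_stack((np.random.normal(-1, 1, size=45),
                                np.random.normal(2, 0.5, size=45))))

fig = interaction_plot(a, b, mse)
ax = fig.axes[0]            # the figure's only Axes
lines = list(ax.get_lines())  # snapshot: assumed one line per level of b

# Sample standard deviation per (b, a) cell, sorted like the plotted lines
stats = (pd.DataFrame({'a': a, 'b': b, 'mse': mse})
         .groupby(['b', 'a'])['mse'].agg(['mean', 'std']).reset_index())

for line, (trace, group) in zip(lines, stats.groupby('b')):
    # Reuse the line's own coordinates and colour so the bars sit on its points
    ax.errorbar(np.asarray(line.get_xdata()), np.asarray(line.get_ydata()),
                yerr=group['std'].to_numpy(), fmt='none',
                ecolor=line.get_color(), capsize=4)
```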
Well, it seems that this feature is not directly supported yet, so I modified the source code and created a new function. I'm posting it here in case it is useful to somebody.
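```python
import numpy as np
import matplotlib.pyplot as plt
from pandas import DataFrame


def int_plot(x, trace, response, func=np.mean, ax=None, plottype='b',
             xlabel=None, ylabel=None, colors=None, markers=None,
             linestyles=None, legendloc='best', legendtitle=None,
             # - - - My changes !!
             errorbars=False, errorbartyp='std',
             # - - - -
             **kwargs):
    # Setup roughly as statsmodels' interaction_plot does it: create axes if
    # needed and map string levels of x to 0, 1, 2, ... in sorted order
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure
    x_levels = None
    if isinstance(x[0], str):
        x_levels = sorted(set(x))
        x = np.array([x_levels.index(v) for v in x])

    data = DataFrame(dict(x=x, trace=trace, response=response))
    plot_data = data.groupby(['trace', 'x']).aggregate(func).reset_index()
    # - - - My changes !!
    if errorbars:
        if errorbartyp == 'std':
            yerr = data.groupby(['trace', 'x']).aggregate(
                lambda xx: np.std(xx, ddof=1)).reset_index()
        elif errorbartyp == 'ci95':
            yerr = data.groupby(['trace', 'x']).aggregate(t_ci).reset_index()
        else:
            raise ValueError("Type of error bars %s not understood" % errorbartyp)
    # - - - - - - -
    n_trace = len(plot_data['trace'].unique())
    colors = colors or ['C%d' % i for i in range(n_trace)]
    markers = markers or ['o'] * n_trace
    linestyles = linestyles or ['-'] * n_trace
    if plottype == 'both' or plottype == 'b':
        # Group by a single key so `values` is the trace level itself
        for i, (values, group) in enumerate(plot_data.groupby('trace')):
            # trace label
            label = str(group['trace'].values[0])
            # - - - My changes !!
            if errorbars:
                # errorbar draws the line too, so ax.plot is not needed here
                ax.errorbar(group['x'], group['response'],
                            yerr=yerr.loc[yerr['trace'] == values]['response'].values,
                            color=colors[i], ecolor='black',
                            marker=markers[i], label=label,
                            linestyle=linestyles[i], **kwargs)
            # - - - - - - - - - -
            else:
                ax.plot(group['x'], group['response'], color=colors[i],
                        marker=markers[i], label=label,
                        linestyle=linestyles[i], **kwargs)
    ax.set_xlabel(xlabel or 'X')
    ax.set_ylabel(ylabel or 'response')
    ax.legend(loc=legendloc, title=legendtitle)
    ax.margins(.1)
    if x_levels is not None:
        ax.set_xticks(range(len(x_levels)))
        ax.set_xticklabels(x_levels)
    return fig
```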
With that, I could get a plot with error bars (image not included) from:

```python
f = int_plot(a, b, mse, errorbars=True, errorbartyp='std')
```
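The error-bar lookup in `int_plot` (`yerr.loc[yerr['trace'] == values]['response'].values`) relies on `yerr` listing the (trace, x) cells in the same order as `plot_data`. A small check with made-up data whose rows are deliberately out of order: pandas `groupby` sorts by its keys by default, so both tables come out in the same order.

```python
import numpy as np
import pandas as pd

# Made-up data: 2 observations per (trace, x) cell, rows in no particular order
data = pd.DataFrame({
    'x':     ['a2', 'a1', 'a3', 'a1', 'a2', 'a3', 'a3', 'a1', 'a2', 'a2', 'a3', 'a1'],
    'trace': ['b2', 'b1', 'b1', 'b2', 'b1', 'b2', 'b2', 'b1', 'b2', 'b1', 'b1', 'b2'],
    'response': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
})

keys = ['trace', 'x']
plot_data = data.groupby(keys)['response'].mean().reset_index()
yerr = data.groupby(keys)['response'].std(ddof=1).reset_index()

# groupby(sort=True) is the default, so both tables share the (trace, x) order
same_order = plot_data[keys].equals(yerr[keys])
print(same_order)  # True
```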
Note: with `errorbartyp='ci95'`, the code uses the function `t_ci()` to aggregate the error bars. I defined it like this:
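```python
import numpy as np
from scipy.stats import t


def t_ci(x, C=0.95):
    # Half-width of a two-sided C-level t confidence interval for the mean of x
    x = np.array(x)
    n = len(x)
    tstat = t.ppf((1 + C) / 2, n - 1)  # upper quantile, n - 1 degrees of freedom
    return np.std(x, ddof=1) * tstat / np.sqrt(n)
```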
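A quick check of which Student's t quantile a two-sided interval needs, using a made-up sample. The half-width is t.ppf((1 + C) / 2, n - 1) * s / sqrt(n), where s is the sample standard deviation; the lower-tail quantile t.ppf((1 - C) / 2, ...) is negative, so it cannot be used as an error-bar length. The result is compared against `scipy.stats.t.interval`.

```python
import numpy as np
from scipy import stats

sample = np.array([2.1, 1.7, 2.5, 1.9, 2.3, 2.0])  # made-up values
n, C = len(sample), 0.95

# Lower-tail quantile, e.g. t.ppf((1 - C) / 2, n): negative
lower_q = stats.t.ppf((1 - C) / 2, n)
# Upper-tail quantile with n - 1 degrees of freedom: what the half-width needs
upper_q = stats.t.ppf((1 + C) / 2, n - 1)
half_width = upper_q * sample.std(ddof=1) / np.sqrt(n)

# Reference: SciPy's own two-sided interval around the sample mean
lo, hi = stats.t.interval(C, n - 1, loc=sample.mean(), scale=stats.sem(sample))
print(lower_q < 0)                                 # True
print(np.isclose(hi - sample.mean(), half_width))  # True
```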
Again, I just tweaked the function a bit to fit my current needs. The original is `interaction_plot` in `statsmodels.graphics.factorplots` :)