Reputation: 191
I was asked to generate a few graphs for a multiple-baseline study design, which uses a specialized type of graph. I took it as an opportunity to learn a bit more about Matplotlib and pandas, but one thing I am struggling with is the line that divides BASE from INTERVENTION. It needs to continue through multiple subplots and also scale well. Is there any way to accomplish this? I have tried experimenting with lines.Line2D and also ConnectionPatch, but I am stuck on scaling and on determining the positions correctly.
My (simple) code so far:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# df is assumed to be a DataFrame with columns 'A', 'N' and 'P'; its definition is not part of this snippet
y = np.array([0,1,2,3,4])
fig, axs = plt.subplots(3, sharex=True, sharey=True)
fig.suptitle("I1 - Reakce na změnu prvku")  # Czech: "I1 - Reaction to a change of element"
axs[0].plot(df.index,df['A'], color='lightblue', label="A")
axs[1].plot(df.index,df['N'], color='darkblue', label="N")
axs[2].plot(df.index,df['P'], color='blue', label="P")
plt.yticks(np.arange(y.min(), y.max(), 1))
plt.show()
[Image: my plot so far, the result of the code above]
[Image: sample graph for context]
Upvotes: 5
Views: 1749
Reputation: 19565
Borrowing from Pablo's helpful answer, it seems that the inverse of fig.transFigure lets you express points from each subplot in figure coordinates, so you can draw lines between all of those points. This is probably the best method, as it makes the start and end points straightforward to determine. Since your x-coordinates conveniently run from 1 to 12, you can also plot each subplot in two parts to leave a gap between points for the annotation line to pass through.
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
y = np.array([0,1,2,3,4])
## recreate your data
df = pd.DataFrame({
    'A': [0, 1, 1, 1, 2, 2, 3, 2, 3] + [float("nan")]*3,
    'N': [1, 0, 0, 2, 1, 1, 2, 3, 3, 3, 3, 3],
    'P': [0, 1, 1, 1, 2, 1, 1, 1, 2, 3, 3, 3],
    },
    index=range(1, 13)
)
fig, axs = plt.subplots(3, sharex=True, sharey=True)
fig.suptitle("I1 - Reakce na změnu prvku")  # Czech: "I1 - Reaction to a change of element"
## create a gap in the line
axs[0].plot(df.index[0:3],df['A'][0:3], color='lightblue', label="A", marker='.')
axs[0].plot(df.index[3:12],df['A'][3:12], color='lightblue', label="A", marker='.')
## create a gap in the line
axs[1].plot(df.index[0:8],df['N'][0:8], color='darkblue', label="N", marker='.')
axs[1].plot(df.index[8:12],df['N'][8:12], color='darkblue', label="N", marker='.')
## create a gap in the line
axs[2].plot(df.index[0:10],df['P'][0:10], color='blue', label="P", marker='.')
axs[2].plot(df.index[10:12],df['P'][10:12], color='blue', label="P", marker='.')
plt.yticks(np.arange(y.min(), y.max(), 1))
## inverse of fig.transFigure: converts display (pixel) coordinates to figure coordinates
transFigure = fig.transFigure.inverted()
## Since your subplots have a ymax value of 3, setting the end y-coordinate
## of each line to just above that value should help it extend beyond the subplot
coord1 = transFigure.transform(axs[0].transData.transform([3.5,3]))
coord2 = transFigure.transform(axs[1].transData.transform([3.5,3.5]))
coord3 = transFigure.transform(axs[1].transData.transform([8.5,3.5]))
coord4 = transFigure.transform(axs[2].transData.transform([8.5,3.5]))
coord5 = transFigure.transform(axs[2].transData.transform([10.5,3.5]))
coord6 = transFigure.transform(axs[2].transData.transform([10.5,0]))
## add a vertical dashed line
line1 = matplotlib.lines.Line2D((coord1[0], coord2[0]), (coord1[1], coord2[1]),
                                transform=fig.transFigure,
                                ls='--',
                                color='grey')
## add a horizontal dashed line
line2 = matplotlib.lines.Line2D((coord2[0], coord3[0]), (coord2[1], coord3[1]),
                                transform=fig.transFigure,
                                ls='--',
                                color='grey')
## add a vertical dashed line
line3 = matplotlib.lines.Line2D((coord3[0], coord4[0]), (coord3[1], coord4[1]),
                                transform=fig.transFigure,
                                ls='--',
                                color='grey')
## add a horizontal dashed line
line4 = matplotlib.lines.Line2D((coord4[0], coord5[0]), (coord4[1], coord5[1]),
                                transform=fig.transFigure,
                                ls='--',
                                color='grey')
## add a vertical dashed line
line5 = matplotlib.lines.Line2D((coord5[0], coord6[0]), (coord5[1], coord6[1]),
                                transform=fig.transFigure,
                                ls='--',
                                color='grey')
## register each line with the figure
for line in (line1, line2, line3, line4, line5):
    fig.add_artist(line)
plt.show()
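The conversion used above (data coordinates to display coordinates, then display coordinates to figure coordinates) can be wrapped in a small helper, and the five two-point lines can be collapsed into a single Line2D with six vertices. This is only a sketch of the same idea: `data_to_figure`, `corners` and `boundary` are names made up here, and the axis limits are assumed values chosen to match the data range in the example.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; remove for interactive use
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def data_to_figure(fig, ax, x, y):
    """Convert a point from the data coordinates of `ax` to figure fractions (0..1)."""
    display_xy = ax.transData.transform((x, y))              # data -> display (pixels)
    return fig.transFigure.inverted().transform(display_xy)  # display -> figure


fig, axs = plt.subplots(3, sharex=True, sharey=True)
# Fix the limits first, because the conversion depends on them
# (these limits are assumptions for this sketch).
for ax in axs:
    ax.set_xlim(0.5, 12.5)
    ax.set_ylim(-0.2, 3.2)

# Corners of the stepped line, each given as (axes, x, y) in that axes' data coordinates.
corners = [
    (axs[0], 3.5, 3.0),
    (axs[1], 3.5, 3.5),
    (axs[1], 8.5, 3.5),
    (axs[2], 8.5, 3.5),
    (axs[2], 10.5, 3.5),
    (axs[2], 10.5, 0.0),
]
xs, ys = zip(*(data_to_figure(fig, ax, x, y) for ax, x, y in corners))

# One artist with six vertices instead of five separate two-point lines.
boundary = Line2D(xs, ys, transform=fig.transFigure, ls="--", color="grey")
fig.add_artist(boundary)
```

Call plt.show() or fig.savefig(...) afterwards to see the result. The vertices are stored as figure fractions, so they are only valid for the axis limits and layout that were in effect when they were computed; if either changes later, the corners have to be recomputed.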
Upvotes: 4
Reputation: 40707
My instinct for this kind of problem is to draw a line in figure coordinates. The one issue I had was finding the midpoint of the gap between consecutive axes. My code is ugly, but it works, and it is independent of the relative size of each axes and of the spacing between them, as the code below demonstrates:
import numpy as np
import matplotlib.pyplot as plt
from itertools import zip_longest
from matplotlib.lines import Line2D

def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

xconn = [0.15, 0.35, 0.7]  # position of the vertical lines in each subplot, in data coordinates
fig, axs = plt.subplots(3, 1, gridspec_kw=dict(hspace=0.6, height_ratios=[2, 0.5, 1]))
#
# Draw the separation line; this should be done at the very end, once the limits of the axes have been set etc.
#
# convert the value of xconn in each axes to figure coordinates
xconn = [fig.transFigure.inverted().transform(ax.transData.transform([x, 0]))[0] for x, ax in zip(xconn, axs)]
yconn = []  # y-values of the connecting lines, in figure coordinates
for ax in axs:
    bbox = ax.get_position()
    yconn.extend([bbox.y1, bbox.y0])
# replace each pair of values (bottom of one axes, top of the next) with their average
yconn[1:-1] = np.ravel([[np.mean(ys)] * 2 for ys in grouper(yconn[1:-1], 2)]).tolist()
l = Line2D(np.repeat(xconn, 2), yconn, transform=fig.transFigure, ls='--', lw=1, c='k')
fig.add_artist(l)
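The same approach could be packaged as a reusable function. The sketch below is one possible reading of the code above, not a tested drop-in: `add_phase_boundary` and its parameters are hypothetical names, the axes are assumed to be passed from top to bottom, and the gap midpoints are computed directly from neighbouring bounding boxes instead of with grouper.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; remove for interactive use
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def add_phase_boundary(fig, axs, x_changes, **line_kw):
    """Draw one stepped line through vertically stacked axes.

    axs       -- axes ordered from top to bottom (assumption of this sketch)
    x_changes -- one x value per axes, in that axes' data coordinates
    """
    to_fig = fig.transFigure.inverted()
    # x position of the vertical segment in each axes, in figure coordinates
    xs = [to_fig.transform(ax.transData.transform((x, 0)))[0]
          for ax, x in zip(axs, x_changes)]
    boxes = [ax.get_position() for ax in axs]
    # y levels: top of the first axes, middle of every gap, bottom of the last axes
    levels = [boxes[0].y1]
    levels += [(upper.y0 + lower.y1) / 2 for upper, lower in zip(boxes, boxes[1:])]
    levels.append(boxes[-1].y0)
    # axes i contributes a vertical segment from levels[i] down to levels[i + 1]
    x_path = np.repeat(xs, 2)
    y_path = np.repeat(levels, 2)[1:-1]
    style = {"ls": "--", "lw": 1, "color": "k", **line_kw}
    line = Line2D(x_path, y_path, transform=fig.transFigure, **style)
    fig.add_artist(line)
    return line


fig, axs = plt.subplots(3, 1, sharex=True,
                        gridspec_kw=dict(hspace=0.6, height_ratios=[2, 0.5, 1]))
for ax in axs:
    ax.set_xlim(0.5, 12.5)  # set the limits before computing the line
line = add_phase_boundary(fig, axs, [3.5, 8.5, 10.5])
```

As in the original code, the function should be called last, after plotting and after the axis limits are final, since the x positions are converted once and do not follow later changes to the limits.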
Upvotes: 3