astromonerd

Reputation: 937

How do I draw a line connecting subplots in pyplot?

I would like to draw lines between subplots in pyplot, as illustrated by the red dashed lines here (added in a PDF editor).

[Image: desired result, with red dashed lines drawn between the subplots]

I've read the documentation on ConnectionPatch, but I'm having difficulty understanding the examples well enough to apply them to my case. Below is a simplified version of my code, which uses the same axes structure in case that is relevant. How do I create these dashed lines between subplots?

import numpy as np
import matplotlib.pyplot as plt

# Create a 2 x 2 grid: (row, column) 
fig, ax = plt.subplots(2,2) 

# Create a subplot to share common x and y labels
fig.add_subplot(111, frameon=False)
# Hide the ticks of the label-only axes (use booleans, not 'on'/'off' strings)
plt.tick_params(
    top=False,
    bottom=False,
    left=False,
    right=False)
plt.grid(False)
plt.xlabel('x')
plt.ylabel('function(x)')

# x-axis
x = np.linspace(0,2*np.pi,100)

# Top left
ax[0,0].tick_params(
    axis='both',
    which='both',
    bottom=False,
    left=False,
    top=False,
    right=False,
    labelbottom=False,
    labelleft=False)
ax[0,0].plot(x,np.sin(x),color='grey')

# Top Right
ax[0,1].tick_params(
    axis='both',
    which='both',
    bottom=False,
    left=False,
    top=False,
    right=False,
    labelbottom=False,
    labelleft=False)
ax[0,1].plot(x,np.sin(2*x),color='grey')


# Bottom Left
ax[1,0].tick_params(
    axis='both',
    which='both',
    bottom=False,
    left=False,
    top=False,
    right=False,
    labelbottom=False,
    labelleft=False)
ax[1,0].plot(x,np.cos(x), color='black')

# Bottom Right
ax[1,1].tick_params(
    axis='both',
    which='both',
    bottom=False,
    left=False,
    top=False,
    right=False,
    labelbottom=False,
    labelleft=False)
ax[1,1].plot(x,np.cos(2*x), color='black')

plt.tight_layout(h_pad=2.5)

Upvotes: 0

Views: 1867

Answers (1)

ImportanceOfBeingErnest

Reputation: 339310

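The ConnectionPatch example shows how to use a ConnectionPatch to connect two axes. For your case, connect the bottom center of each top-row axes to the top center of the axes below it, using "axes fraction" coordinates for both ends: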

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch

fig, axes = plt.subplots(2,2) 

# Create a subplot to share common x and y labels
frameax = fig.add_subplot(111, frameon=False)
frameax.grid(False)
frameax.set_xlabel('x', labelpad=10)
frameax.set_ylabel('function(x)',labelpad=10)

for ax in list(axes.flat) + [frameax]:
    ax.tick_params(axis='both', which='both', 
                   bottom=False, left=False, top=False, right=False,
                   labelbottom=False, labelleft=False)

# x-axis
x = np.linspace(0,2*np.pi,100)

axes[0,0].plot(x,np.sin(x),color='grey')
axes[0,1].plot(x,np.sin(2*x),color='grey')
axes[1,0].plot(x,np.cos(x), color='black')
axes[1,1].plot(x,np.cos(2*x), color='black')

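kw = dict(linestyle="--", color="red")
# Connect the bottom center of each top-row axes (axes fraction (0.5, 0))
# to the top center of the axes directly below it (axes fraction (0.5, 1)).
cp1 = ConnectionPatch((.5, 0), (.5, 1), "axes fraction", "axes fraction",
                      axesA=axes[0,0], axesB=axes[1,0], **kw)
cp2 = ConnectionPatch((.5, 0), (.5, 1), "axes fraction", "axes fraction",
                      axesA=axes[0,1], axesB=axes[1,1], **kw)

# Add both patches to axes[1,1], the last of the four subplots, so the lines
# are drawn after the other panels instead of being hidden behind them.
for cp in (cp1, cp2):
    axes[1,1].add_artist(cp)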

plt.show()

[Image: the figure produced by the code above, with red dashed lines connecting each top panel to the panel below it]
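A possible variant, sketched under the assumption of a reasonably recent Matplotlib (recent documentation adds connection patches to the figure with `fig.add_artist`): if each ConnectionPatch is added to the figure instead of to `axes[1,1]`, you no longer have to pick the axes that happens to be drawn last, and a loop over the columns avoids repeating the constructor. The shared-label frame axes is left out for brevity, and `curves` and `connectors` are just names chosen for this sketch.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch

fig, axes = plt.subplots(2, 2)
x = np.linspace(0, 2*np.pi, 100)

# Same curves as in the question; ticks and tick labels hidden on every panel.
curves = {(0, 0): (np.sin(x), 'grey'), (0, 1): (np.sin(2*x), 'grey'),
          (1, 0): (np.cos(x), 'black'), (1, 1): (np.cos(2*x), 'black')}
for (row, col), (y, color) in curves.items():
    ax = axes[row, col]
    ax.plot(x, y, color=color)
    ax.tick_params(axis='both', which='both',
                   bottom=False, left=False, top=False, right=False,
                   labelbottom=False, labelleft=False)

# One red dashed line per column, from the bottom center of the top panel
# (axes fraction (0.5, 0)) to the top center of the panel below it (0.5, 1).
connectors = []
for col in range(axes.shape[1]):
    cp = ConnectionPatch((0.5, 0), (0.5, 1), "axes fraction", "axes fraction",
                         axesA=axes[0, col], axesB=axes[1, col],
                         linestyle="--", color="red")
    # Added to the figure rather than to an axes: the patch is not clipped to
    # any subplot, and with the default patch zorder it is drawn after the axes.
    fig.add_artist(cp)
    connectors.append(cp)

fig.tight_layout(h_pad=2.5)
plt.show()
```

Because the endpoints are given in axes-fraction coordinates, they are recomputed at draw time, so the lines should follow the panels after `tight_layout` or a window resize.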

Upvotes: 4
