Mate de Vita

How to annotate a regression line with the proper text rotation

I have the following snippet of code to draw a best-fit line through a collection of points on a graph and annotate it with the corresponding R² value:

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

x = 50 * np.random.rand(20) + 50
y = 200 * np.random.rand(20)
plt.plot(x, y, 'o')

# k, n = np.polyfit(x, y, 1)
k, n, r, _, _ = scipy.stats.linregress(x, y)
line = plt.axline((0, n), slope=k, color='blue')
xy = line.get_xydata()
plt.annotate(
    f'$R^2={r**2:.3f}$',
    (xy[0] + xy[-1]) // 2,
    xycoords='axes fraction',
    ha='center', va='center_baseline',
    rotation=k, rotation_mode='anchor',
)

plt.show()

I have tried various (x, y) pairs, different xycoords values and other keyword parameters in annotate, but I haven't been able to get the annotation to appear where I want it. How do I get the text annotation to appear above the line with the proper rotation, located either at the midpoint of the line or at either end?

Answers (1)

tdy

1. Annotation coordinates

We cannot compute the coordinates from xydata here, because axline() returns a line with placeholder xydata; the visible endpoints of the infinite line are only computed at draw time from the current view limits:

print(line.get_xydata())
# [[0. 0.]
#  [1. 1.]]

Instead, we can compute the text coordinates from the current x-limits (plt.xlim()):

xmin, xmax = plt.xlim()
xtext = (xmin + xmax) / 2
ytext = k*xtext + n

Note that these are data coordinates, so they should be used with xycoords='data' instead of 'axes fraction'.
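The question also asks for a label at either end of the line. A minimal sketch, assuming k and n come from linregress as above: line_anchor is a hypothetical helper (not part of Matplotlib) that picks the x position from the x-limits and returns a matching horizontal alignment, so that with rotation_mode='anchor' an end label extends inward rather than past the axes edge. It does not check that the resulting y value lies inside the y-limits.

```python
def line_anchor(k, n, xlim, where="middle", margin=0.05):
    """Return (x, y, ha) for a label on the line y = k*x + n.

    Hypothetical helper: `where` is "start", "middle" or "end", and
    `margin` is the fraction of the x-range kept between an end label
    and the axes edge.
    """
    xmin, xmax = xlim
    span = xmax - xmin
    if where == "start":
        x, ha = xmin + margin * span, "left"    # text grows to the right
    elif where == "end":
        x, ha = xmax - margin * span, "right"   # text grows to the left
    elif where == "middle":
        x, ha = xmin + span / 2, "center"
    else:
        raise ValueError(f"unknown position: {where!r}")
    # y on the regression line, in data coordinates
    return x, k * x + n, ha


print(line_anchor(2.0, 1.0, (0.0, 10.0)))  # (5.0, 11.0, 'center')
```

The returned values go straight into plt.annotate(label, (x, y), xycoords='data', ha=ha, va='bottom', rotation_mode='anchor', rotation=...), with the rotation computed as in the next section.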


2. Annotation angle

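We cannot compute the angle purely from the line itself. Passing rotation=k treats the slope as a number of degrees, and even np.rad2deg(np.arctan(k)) is only the angle in data coordinates. The angle on screen also depends on the axis limits and the figure dimensions (e.g., imagine the required rotation angle for the same line in a 6x4 figure vs a 2x8 figure).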
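To put numbers on the 6x4 vs 2x8 example, here is a small sketch computing the on-screen angle of one line (slope 1 over a 0 to 10 range on both axes) for three figure shapes. screen_angle is a hypothetical helper, and like the normalization in the code below it treats the figure size as a stand-in for the axes size.

```python
import numpy as np


def screen_angle(k, xlim, ylim, size):
    """On-screen angle, in degrees, of a line with slope k (data units).

    Hypothetical helper: assumes the axes box has the same aspect ratio
    as a figure of `size` = (width, height) in inches.
    """
    (xmin, xmax), (ymin, ymax) = xlim, ylim
    width, height = size
    run = width / (xmax - xmin)        # figure inches per unit of x
    rise = k * height / (ymax - ymin)  # figure inches moved up per unit of x
    return np.degrees(np.arctan2(rise, run))


# the same line in three figure shapes
print(round(screen_angle(1.0, (0, 10), (0, 10), (4, 4)), 1))  # 45.0
print(round(screen_angle(1.0, (0, 10), (0, 10), (6, 4)), 1))  # 33.7
print(round(screen_angle(1.0, (0, 10), (0, 10), (2, 8)), 1))  # 76.0
```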

Instead, we normalize the run and rise of the line by the axis ranges and the figure size to get the visual rotation:

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

rs = np.random.RandomState(0)
x = 50 * rs.rand(20) + 50
y = 200 * rs.rand(20)
plt.plot(x, y, 'o')

# save ax and fig scales
xmin, xmax = plt.xlim()
ymin, ymax = plt.ylim()
xfig, yfig = plt.gcf().get_size_inches()

k, n, r, _, _ = scipy.stats.linregress(x, y)
plt.axline((0, n), slope=k, color='blue')

# restore x and y limits after axline
plt.xlim(xmin, xmax)
plt.ylim(ymin, ymax)

# find text coordinates at midpoint of regression line
xtext = (xmin + xmax) / 2
ytext = k*xtext + n

# run and rise of the line over one unit of x
# (stays correct even when xtext is zero or negative)
dx = 1
dy = k

# normalize to axis ranges and figure size (assumes the axes box has
# roughly the same aspect ratio as the figure)
xnorm = dx * xfig / (xmax - xmin)
ynorm = dy * yfig / (ymax - ymin)

# find normalized annotation angle in degrees
rotation = np.rad2deg(np.arctan2(ynorm, xnorm))

plt.annotate(
    f'$R^2={r**2:.3f}$',
    (xtext, ytext), xycoords='data',
    ha='center', va='bottom',
    rotation=rotation, rotation_mode='anchor',
)
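The normalization above uses the figure size as a stand-in for the axes size. That is close for a single default subplot (Matplotlib's default subplot margins leave 77.5% of the figure width and 77% of its height for the axes) but drifts with subplot grids, colorbars or tight_layout. A sketch of an alternative, assuming linear axis scales: map two points of the line through ax.transData into display (pixel) coordinates and measure the angle there, so the real axes box is used. label_regression_line is a hypothetical helper, not a Matplotlib function, and like the original it computes the angle once, so it goes stale if the limits or the figure size change afterwards.

```python
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats


def label_regression_line(ax, k, n, text, frac=0.5, ha="center"):
    """Label the line y = k*x + n at fraction `frac` of the current x-range.

    Hypothetical helper: assumes linear axis scales and that the axis
    limits and figure size do not change after the call.
    """
    xmin, xmax = ax.get_xlim()
    xtext = xmin + frac * (xmax - xmin)
    ytext = k * xtext + n
    # the line's points at both x-limits, in data coordinates
    ends = np.array([[xmin, k * xmin + n], [xmax, k * xmax + n]])
    # map them to display (pixel) coordinates; this accounts for the real
    # axes box, the current axis ranges and the figure dpi
    (px0, py0), (px1, py1) = ax.transData.transform(ends)
    angle = np.degrees(np.arctan2(py1 - py0, px1 - px0))
    return ax.annotate(
        text, (xtext, ytext), xycoords="data",
        ha=ha, va="bottom",
        rotation=angle, rotation_mode="anchor",
    )


rs = np.random.RandomState(0)
x = 50 * rs.rand(20) + 50
y = 200 * rs.rand(20)

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(x, y, "o")
xlim, ylim = ax.get_xlim(), ax.get_ylim()

k, n, r, _, _ = scipy.stats.linregress(x, y)
ax.axline((0, n), slope=k, color="blue")
# axline adds (0, n) to the data limits, so restore the original view
ax.set_xlim(xlim)
ax.set_ylim(ylim)

label = label_regression_line(ax, k, n, f"$R^2={r**2:.3f}$")
plt.show()
```

Passing frac=0.05 with ha="left", or frac=0.95 with ha="right", places the label near either end of the line instead. For plain ax.text() labels that should survive resizing, Matplotlib also offers transform_rotates_text=True, which interprets the rotation in data space (e.g. np.rad2deg(np.arctan(k))) and converts it at draw time; see the "Text Rotation Relative To Line" gallery example.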
