KOB

Reputation: 4555

How to plot x and y intercepts at different points of a curve

I have trained a binary classification model.

I can get pairs of precision and recall values at different decision thresholds of the model like this:

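from sklearn.metrics import precision_recall_curve

# probability of the positive class for each test sample
test_prob = model.predict_proba(test_x)[:, 1]
precisions, recalls, thresholds = precision_recall_curve(test_y, test_prob)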
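To make explicit what each (precision, recall) pair means: at a given threshold, a sample is classified as positive when its score is at or above the threshold (as far as I know, this matches scikit-learn's convention for precision_recall_curve). The minimal numpy sketch below computes one such pair on made-up toy data; precision_recall_at is a hypothetical helper, not a library function, and returning precision 1.0 when nothing is predicted positive is a convention chosen here.

```python
import numpy as np


def precision_recall_at(y_true, scores, threshold):
    """Hypothetical helper: precision and recall when predicting positive for scores >= threshold."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(scores) >= threshold
    tp = np.sum(y_pred & y_true)    # true positives
    fp = np.sum(y_pred & ~y_true)   # false positives
    fn = np.sum(~y_pred & y_true)   # false negatives
    # Convention (assumption): precision is 1.0 if nothing is predicted positive
    precision = tp / (tp + fp) if (tp + fp) > 0 else 1.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return float(precision), float(recall)


# Toy labels and scores, made up for illustration
y = [0, 0, 1, 1, 1, 0]
s = [0.1, 0.4, 0.35, 0.8, 0.9, 0.7]
print(precision_recall_at(y, s, 0.5))  # both values are 2/3 for this toy data
```

Sweeping the threshold over the distinct scores traces out the PR curve point by point.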

I can plot the PR curve with matplotlib like this:

import matplotlib.pyplot as plt

# pr_auc is computed elsewhere (not shown)
plt.plot(recalls, precisions, label=f"Chargebacks (AUC = {round(pr_auc, 2)})", c="b")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.legend()
plt.show()

and this produces this plot:

[Image: the resulting precision-recall curve]

I can also create a dataframe of the corresponding precision and recall pairs for different decision thresholds like this:

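import pandas as pd

# precision_recall_curve returns one more precision/recall value than thresholds
# (the final point, precision=1 and recall=0, has no threshold), so drop it to align the columns
threshold_df = pd.DataFrame(
    {
        "Threshold": thresholds,
        "Precision": precisions[:-1],
        "Recall": recalls[:-1],
    }
)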

and this produces this dataframe:

     Threshold  Precision    Recall
0     0.000000   0.005016  1.000000
1     0.002222   0.056515  0.990991
2     0.010000   0.056555  0.990991
3     0.020000   0.113995  0.989189
4     0.030000   0.163076  0.981982
5     0.031667   0.203295  0.978378
6     0.031667   0.203371  0.978378
7     0.040000   0.203447  0.978378
8     0.050000   0.243341  0.971171
9     0.060000   0.282347  0.971171
10    0.070000   0.321128  0.963964
11    0.080000   0.355898  0.956757
12    0.090000   0.383883  0.944144
13    0.100000   0.405594  0.940541
14    0.110000   0.431063  0.935135
15    0.120000   0.460036  0.933333
16    0.130000   0.484082  0.931532
17    0.140000   0.508374  0.929730
18    0.150000   0.530864  0.929730
19    0.160000   0.550694  0.929730
20    0.170000   0.571109  0.918919
21    0.180000   0.587082  0.917117
22    0.190000   0.607914  0.913514
23    0.200000   0.622850  0.913514
24    0.210000   0.644955  0.909910
25    0.220000   0.653696  0.908108
26    0.230000   0.665779  0.900901
27    0.240000   0.680384  0.893694
28    0.250000   0.688456  0.891892
29    0.260000   0.698300  0.888288
30    0.270000   0.700855  0.886486
31    0.280000   0.706052  0.882883
32    0.290000   0.711790  0.881081
33    0.300000   0.719764  0.879279
34    0.310000   0.726727  0.872072
35    0.320000   0.730594  0.864865
36    0.330000   0.735069  0.864865
37    0.340000   0.744946  0.863063
38    0.350000   0.750392  0.861261
39    0.360000   0.756757  0.857658
40    0.370000   0.761218  0.855856
41    0.380000   0.766990  0.854054
42    0.390000   0.768852  0.845045
43    0.400000   0.777778  0.845045
44    0.410000   0.781513  0.837838
45    0.420000   0.787053  0.832432
46    0.430000   0.791096  0.832432
47    0.439630   0.792746  0.827027
48    0.440000   0.792388  0.825225
49    0.450000   0.793043  0.821622
50    0.460000   0.793345  0.816216
51    0.470000   0.799645  0.812613
52    0.480000   0.803220  0.809009
53    0.490000   0.805755  0.807207
54    0.500000   0.809872  0.798198
55    0.510000   0.809524  0.796396
56    0.520000   0.814815  0.792793
57    0.530000   0.819887  0.787387
58    0.540000   0.823864  0.783784
59    0.550000   0.825670  0.776577
60    0.560000   0.826590  0.772973
61    0.570000   0.828125  0.763964
62    0.580000   0.827789  0.762162
63    0.590000   0.832016  0.758559
64    0.600000   0.831349  0.754955
65    0.610000   0.832335  0.751351
66    0.620000   0.834694  0.736937
67    0.630000   0.836066  0.735135
68    0.640000   0.844075  0.731532
69    0.650000   0.845511  0.729730
70    0.660000   0.844211  0.722523
71    0.670000   0.846809  0.717117
72    0.680000   0.846482  0.715315
73    0.690000   0.850649  0.708108
74    0.700000   0.857768  0.706306
75    0.710000   0.863135  0.704505
76    0.720000   0.868889  0.704505
77    0.730000   0.876404  0.702703
78    0.740000   0.876147  0.688288
79    0.750000   0.875862  0.686486
80    0.760000   0.874126  0.675676
81    0.770000   0.874408  0.664865
82    0.780000   0.872596  0.654054
83    0.790000   0.882064  0.646847
84    0.800000   0.883085  0.639640
85    0.810000   0.887218  0.637838
86    0.820000   0.890585  0.630631
87    0.830000   0.890625  0.616216
88    0.840000   0.898396  0.605405
89    0.850000   0.898907  0.592793
90    0.860000   0.899441  0.580180
91    0.870000   0.901449  0.560360
92    0.880000   0.903904  0.542342
93    0.890000   0.907407  0.529730
94    0.900000   0.911672  0.520721
95    0.910000   0.912621  0.508108
96    0.920000   0.915541  0.488288
97    0.930000   0.916955  0.477477
98    0.940000   0.927536  0.461261
99    0.950000   0.932331  0.446847
100   0.960000   0.931174  0.414414
101   0.970000   0.939130  0.389189
102   0.980000   0.938095  0.354955
103   0.990000   0.935484  0.313514
104   1.000000   0.928058  0.232432

On the same plot as the PR curve, I now want to draw horizontal dotted lines at precision values [0.1, 0.2, ..., 0.9] (or the closest available values in the dataframe above) that run from the y-axis to the blue curve and then drop vertically to the x-axis. Each line should be labelled with the corresponding 'Threshold' from the dataframe.

How can I achieve this?

The final plot should look something like this:

[Image: the desired plot]

EDIT:

Instead of drawing the intercepts at every precision = [0.1, ..., 0.9], it would make more sense to plot them for every threshold = [0.1, ..., 0.9], but the same question still stands with this adjustment.

Upvotes: 0

Views: 255

Answers (1)

JohanC

Reputation: 80409

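idx = (np.abs(threshold - t)).argmin() finds the index of the value in threshold nearest to t. That index can then be used to draw the lines and position the label. Lines for a given precision are found the same way, by searching precisions instead of threshold: in the example below, the left panel does that and the right panel uses thresholds.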
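The lookup handles one target at a time. If you want all nine indices at once, a vectorized variant using numpy broadcasting could look like the sketch below; nearest_indices is a hypothetical helper name, and the sample values are a few thresholds copied from the table in the question. Like argmin itself, it returns the first index on ties, and it always returns some index even if no value is actually close to the target.

```python
import numpy as np


def nearest_indices(values, targets):
    """Hypothetical helper: for each target, the index of the closest entry in values."""
    values = np.asarray(values)
    targets = np.asarray(targets)
    # (len(targets), len(values)) matrix of absolute differences, argmin along each row
    return np.abs(values[None, :] - targets[:, None]).argmin(axis=1)


# A few thresholds copied from the table in the question
threshold = np.array([0.0, 0.031667, 0.1, 0.43963, 0.44, 0.5, 0.9, 1.0])
print(nearest_indices(threshold, [0.1, 0.44, 0.93]))  # [2 4 6]
```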

import matplotlib.pyplot as plt
import numpy as np

threshold = np.array([0.0, 0.002222, 0.01, 0.02, 0.03, 0.031667, 0.031667, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41, 0.42, 0.43, 0.43963, 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.0])
precisions =  np.array([0.005016, 0.056515, 0.056555, 0.113995, 0.163076, 0.203295, 0.203371, 0.203447, 0.243341, 0.282347, 0.321128, 0.355898, 0.383883, 0.405594, 0.431063, 0.460036, 0.484082, 0.508374, 0.530864, 0.550694, 0.571109, 0.587082, 0.607914, 0.62285, 0.644955, 0.653696, 0.665779, 0.680384, 0.688456, 0.6983, 0.700855, 0.706052, 0.71179, 0.719764, 0.726727, 0.730594, 0.735069, 0.744946, 0.750392, 0.756757, 0.761218, 0.76699, 0.768852, 0.777778, 0.781513, 0.787053, 0.791096, 0.792746, 0.792388, 0.793043, 0.793345, 0.799645, 0.80322, 0.805755, 0.809872, 0.809524, 0.814815, 0.819887, 0.823864, 0.82567, 0.82659, 0.828125, 0.827789, 0.832016, 0.831349, 0.832335, 0.834694, 0.836066, 0.844075, 0.845511, 0.844211, 0.846809, 0.846482, 0.850649, 0.857768, 0.863135, 0.868889, 0.876404, 0.876147, 0.875862, 0.874126, 0.874408, 0.872596, 0.882064, 0.883085, 0.887218, 0.890585, 0.890625, 0.898396, 0.898907, 0.899441, 0.901449, 0.903904, 0.907407, 0.911672, 0.912621, 0.915541, 0.916955, 0.927536, 0.932331 , 0.931174, 0.93913, 0.938095, 0.935484, 0.928058])
recalls = np.array([1.0, 0.990991, 0.990991, 0.989189, 0.981982, 0.978378, 0.978378, 0.978378, 0.971171, 0.971171, 0.963964, 0.956757, 0.944144, 0.940541, 0.935135, 0.933333, 0.931532, 0.92973, 0.92973, 0.92973, 0.918919, 0.917117, 0.913514, 0.913514, 0.90991, 0.908108, 0.900901, 0.8936940, 0.891892, 0.888288, 0.886486, 0.882883, 0.881081, 0.879279, 0.872072, 0.864865, 0.864865, 0.863063, 0.861261, 0.857658, 0.855856, 0.854054, 0.845045, 0.845045, 0.837838, 0.832432, 0.832432, 0.8270270000000001, 0.825225, 0.821622, 0.816216, 0.812613, 0.809009, 0.807207, 0.798198, 0.796396, 0.792793, 0.787387, 0.783784, 0.776577, 0.772973, 0.763964, 0.762162, 0.758559, 0.754955, 0.751351, 0.736937, 0.735135, 0.731532, 0.72973, 0.722523, 0.717117, 0.715315, 0.708108, 0.706306, 0.704505, 0.704505, 0.702703, 0.688288, 0.686486, 0.675676, 0.664865, 0.654054, 0.646847, 0.63964, 0.637838, 0.630631, 0.616216, 0.605405, 0.592793, 0.58018, 0.56036, 0.542342, 0.52973, 0.520721, 0.508108, 0.488288, 0.477477, 0.461261, 0.446847, 0.414414, 0.389189, 0.354955, 0.313514, 0.232432])

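fig, axs = plt.subplots(ncols=2, figsize=(10, 4))

for k, ax in enumerate(axs):
    # 0.85 is a placeholder AUC for this example
    ax.plot(recalls, precisions, label=f"Chargebacks (AUC = {round(0.85, 2)})", c="b")

    if k == 0:
        # Left panel: lines at the precisions closest to 0.1, ..., 0.9
        for p in np.arange(0.1, 1, 0.1):
            idx = (np.abs(precisions - p)).argmin()
            # down from the curve to the x-axis, and across from the curve to the y-axis
            ax.plot([recalls[idx], recalls[idx], 0], [0, precisions[idx], precisions[idx]], c='crimson', ls=':')
            # label with the threshold that gives this precision
            ax.text(0.02, precisions[idx], f"{threshold[idx]:.2f}", color='crimson', fontsize=10, va='bottom', ha='left')
    else:
        # Right panel: lines at the thresholds closest to 0.1, ..., 0.9
        for i in range(1, 10):
            t = i * 0.1
            idx = (np.abs(threshold - t)).argmin()
            ax.plot([recalls[idx], recalls[idx], 0], [0, precisions[idx], precisions[idx]], c='crimson', ls=':')
            # alternate the x offset so labels of nearby lines don't overlap
            ax.text(0.02 if i % 2 == 1 else 0.07, precisions[idx], f"{threshold[idx]:.2f}", color='black', fontsize=10, va='bottom', ha='left')

    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.legend()
plt.show()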

[Image: example plot]
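If the guides are needed on several plots, the loop can be wrapped in a function. The sketch below is one possible way to do that, not the answer author's code: add_threshold_intercepts is a hypothetical name, the guides are drawn dotted (ls=":") as the question asked, a target that resolves to the same row as an earlier target is drawn only once, and the Agg backend is selected only so the example runs without a display (drop that line in an interactive session and call plt.show()). The toy curve uses five rows copied from the table in the question.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, only so this sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def add_threshold_intercepts(ax, recalls, precisions, thresholds, targets, color="crimson"):
    """Hypothetical helper: dotted guides from both axes to the PR curve at the
    thresholds closest to each target, labelled with the actual threshold."""
    recalls = np.asarray(recalls)
    precisions = np.asarray(precisions)
    thresholds = np.asarray(thresholds)
    used = []
    for t in targets:
        idx = int(np.abs(thresholds - t).argmin())
        if idx in used:  # two targets resolved to the same point; draw it once
            continue
        used.append(idx)
        r, p = recalls[idx], precisions[idx]
        # across from the y-axis to the curve, then down to the x-axis
        ax.plot([0, r, r], [p, p, 0], ls=":", c=color)
        ax.text(0.02, p, f"{thresholds[idx]:.2f}", color=color, fontsize=9, va="bottom", ha="left")
    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)
    return used


# Five rows copied from the table in the question
thr = np.array([0.1, 0.3, 0.5, 0.7, 0.9])
prec = np.array([0.405594, 0.719764, 0.809872, 0.857768, 0.911672])
rec = np.array([0.940541, 0.879279, 0.798198, 0.706306, 0.520721])

fig, ax = plt.subplots()
ax.plot(rec, prec, c="b", label="PR curve")
drawn = add_threshold_intercepts(ax, rec, prec, thr, np.arange(0.1, 1.0, 0.2))
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend()
```

With the full arrays from the answer, passing np.arange(0.1, 1, 0.1) as targets should give the same kind of plot as the right-hand panel.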

Upvotes: 1
