taurus

PyPlot Change Scatter Label When Points Overlap

I am graphing the predicted and actual results of an ML project using pyplot. I have a scatter plot of each dataset in one subplot, and the Y values are elements of [-1, 0, 1]. I would like to change the color of the points wherever both points have the same X and Y value, but I am not sure how to implement this. Here is my code so far:

import matplotlib.pyplot as plt

Y = [1, 0, -1, 0, 1]
Z = [1, 1, 1, 1, 1]

plt.subplots()
plt.title('Title')
plt.xlabel('Timestep')
plt.ylabel('Score')
plt.scatter(x = [i for i in range(len(Y))], y = Y, label = 'Actual')
plt.scatter(x = [i for i in range(len(Y))], y = Z, label = 'Predicted')
plt.legend()

Answers (1)

Sheldore

I would simply use NumPy boolean indexing in this case. First plot all the data points, then additionally highlight only those points where the actual and predicted scores coincide, i.e. where Y == Z at the same timestep X.

import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()

Y = np.array([1, 0, -1, 0, 1])  # actual scores
Z = np.array([1, 1, 1, 1, 1])   # predicted scores

X = np.arange(len(Y))  # timesteps

# Labels and titles here

plt.scatter(X, Y, label='Actual')
plt.scatter(X, Z, label='Predicted')

# Boolean mask: True wherever the actual and predicted points overlap
match = Y == Z
# Draw a larger red marker on top of every overlapping pair
plt.scatter(X[match], Y[match], color='red', s=500, label='Actual = Predicted')
plt.xticks(X)
plt.legend()
plt.show()
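If you would rather recolor the overlapping points themselves than draw an extra marker on top of them, scatter also accepts one color per point through its c argument. The sketch below reuses the question's Y (actual) and Z (predicted) values. The helper point_colors and the green/orange color choices are hypothetical illustrations, not part of the original answer. Because the predicted series is drawn after the actual series, a predicted point that lands exactly on an actual point hides it, so matches show up in the match color.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

Y = np.array([1, 0, -1, 0, 1])  # actual scores
Z = np.array([1, 1, 1, 1, 1])   # predicted scores
X = np.arange(len(Y))           # timesteps


def point_colors(actual, predicted, match="green", miss="tab:orange"):
    """Hypothetical helper: one color per timestep, `match` where predicted == actual."""
    return [match if a == p else miss for a, p in zip(actual, predicted)]


colors = point_colors(Y, Z)

fig, ax = plt.subplots()
ax.set_title("Title")
ax.set_xlabel("Timestep")
ax.set_ylabel("Score")
ax.scatter(X, Y, color="tab:blue")
# Drawn second, so at a match the green predicted point covers the actual point
ax.scatter(X, Z, c=colors)
ax.set_xticks(X)

# A scatter with per-point colors has no single legend color, so build the entries by hand
handles = [
    Line2D([], [], marker="o", linestyle="", color="tab:blue", label="Actual"),
    Line2D([], [], marker="o", linestyle="", color="tab:orange", label="Predicted"),
    Line2D([], [], marker="o", linestyle="", color="green", label="Actual = Predicted"),
]
ax.legend(handles=handles)
plt.show()
```

Compared with overlaying a larger marker, this keeps every point the same size. The trade-off is that the legend has to be assembled manually, because a single scatter call with mixed colors has no one color to show in the legend.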
