Reputation: 187
I want to add a marker to those matplotlib plots (out of a group of plots) that satisfy a condition. The plots have different axis ranges (xlim, ylim), so I'd like the marker's position to be independent of these values.
Here is a simple, generic script that illustrates the idea:
import random
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
# Draw 30 random scatter plots one after another (clear_output(wait=True)
# replaces the previous figure in a Jupyter cell). Each plot shows 20 random
# points and marks whether their mean lies within radius 25 of the origin.
for _ in range(30):
    clear_output(wait=True)
    # Random centre in [-10, 10] and random spread in [0, 200] per axis.
    y = np.random.normal(random.randint(-10, 10), random.randint(0, 200), 20)
    x = np.random.normal(random.randint(-10, 10), random.randint(0, 200), 20)

    plt.grid(True)
    plt.xlim([-200, 200])
    plt.ylim([-200, 200])
    # Translucent target disc at the origin (scatter size `s` is in points**2,
    # not data units).
    plt.scatter(0, 0, color='blue', marker="o", alpha=0.1, s=3000)
    plt.scatter(x, y)

    # Compute the means once instead of re-evaluating them in each branch.
    x_mean, y_mean = np.mean(x), np.mean(y)
    # Raw strings: '\o' and '\m' are invalid Python escape sequences
    # (SyntaxWarning on Python 3.12+); mathtext needs the literal backslash.
    if x_mean**2 + y_mean**2 <= 25**2:
        plt.scatter(x_mean, y_mean, color='lightgreen', marker=r'$\odot$', s=5000)
        # This label is anchored in data coordinates, so it moves if the
        # axis limits change.
        plt.scatter(100, 0, color='lightgreen', marker=r'$\mathrm{correct!}$', s=7000)
    else:
        plt.scatter(0, 0, color='red', marker='$X$', s=1000)
        plt.scatter(x_mean, y_mean, color='red', marker=r'$\odot$', s=500)
    plt.show()
In this script, a green marker displays "correct!" at a fixed position in data coordinates. I would like to show it instead as a kind of legend, at a fixed location in the plot regardless of the axis limits. How can I do this?
Upvotes: 0
Views: 404
Reputation: 19590
You can build a custom legend and specify the shape, color, and marker type, and display different legends based on whether you want to label a marker as correct or incorrect. For example:
import random
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
from matplotlib.lines import Line2D
# Legend proxy artists: these Line2D objects are never added to the axes; they
# are passed to plt.legend(handles=...) only to define what the legend shows.
legend_correct_element = [Line2D([0], [0], marker='s', color='w', label='Correct',
                                 markerfacecolor='lightgreen', markersize=15)]
legend_incorrect_element = [Line2D([0], [0], marker='X', color='w', label='Incorrect',
                                   markerfacecolor='red', markersize=15)]

# A named legend location keeps the legend fixed in axes coordinates. The
# default loc='best' would move it around depending on where the data lies.
LEGEND_LOC = 'upper right'

# For reproducibility, seed BOTH generators ONCE, before the loop. Seeding only
# NumPy inside the loop leaves `random.randint` unseeded, and it also makes
# every iteration reuse the same underlying normal draws (only loc/scale vary).
random.seed(42)
np.random.seed(42)

for _ in range(30):
    clear_output(wait=True)
    y = np.random.normal(random.randint(-10, 10), random.randint(0, 200), 20)
    x = np.random.normal(random.randint(-10, 10), random.randint(0, 200), 20)

    plt.grid(True)
    plt.xlim([-200, 200])
    plt.ylim([-200, 200])
    plt.scatter(0, 0, color='blue', marker="o", alpha=0.1, s=3000)
    plt.scatter(x, y)

    x_mean, y_mean = np.mean(x), np.mean(y)
    # Raw strings: '\o' and '\m' are invalid Python escape sequences
    # (SyntaxWarning on Python 3.12+); mathtext needs the literal backslash.
    if x_mean**2 + y_mean**2 <= 25**2:
        plt.scatter(x_mean, y_mean, color='lightgreen', marker=r'$\odot$', s=5000)
        plt.scatter(100, 0, color='lightgreen', marker=r'$\mathrm{correct!}$', s=7000)
        plt.legend(handles=legend_correct_element, loc=LEGEND_LOC)
    else:
        plt.scatter(0, 0, color='red', marker='$X$', s=1000)
        plt.scatter(x_mean, y_mean, color='red', marker=r'$\odot$', s=500)
        plt.legend(handles=legend_incorrect_element, loc=LEGEND_LOC)
    plt.show()
Upvotes: 1