bli

Reputation: 8194

Setting a legend matching the colours in pyplot.scatter

Suppose my data is organized in the following way:

x_values = [6.2, 3.6, 7.3, 3.2, 2.7]
y_values = [1.5, 3.2, 5.4, 3.1, 2.8]
colours = [1, 1, 0, 1, -1]
labels = ["a", "a", "b", "a", "c"]

I want to make a scatterplot with this:

import matplotlib.pyplot as plt

axis = plt.gca()
axis.scatter(x_values, y_values, c=colours)

I want a legend with 3 categories: "a", "b" and "c".

Can I use the labels list to make this legend, given that the categories in this list match the order of the points in the colours list?

Do I need to run the scatter command separately for each category?

Upvotes: 4

Views: 6021

Answers (3)

Alexander Chervov

Reputation: 944

Just a remark, not exactly answering the question:

If you use seaborn, the plot and its legend take exactly one call:

import seaborn as sns 
x_values = [6.2, 3.6, 7.3, 3.2, 2.7]
y_values = [1.5, 3.2, 5.4, 3.1, 2.8]
# colors = [1, 1, 0, 1, -1]  # not needed: hue takes the labels directly
labels = ["a", "a", "b", "a", "c"]
ax = sns.scatterplot(x=x_values, y=y_values, hue=labels)


PS

The question is about matplotlib, though. Besides the other answers here, one might look at the "Automated legend creation" subsection of https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/scatter_with_legend.html

However, I find those examples hard to adapt to this case.
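For what it's worth, here is a rough sketch of how the automated route could be adapted to the data in the question. It assumes matplotlib 3.1 or later (where `legend_elements()` is available on the object returned by `scatter`) and that every colour value corresponds to exactly one label. The names `label_for` and `names` are made up for illustration.

```python
import matplotlib.pyplot as plt

x_values = [6.2, 3.6, 7.3, 3.2, 2.7]
y_values = [1.5, 3.2, 5.4, 3.1, 2.8]
colours = [1, 1, 0, 1, -1]
labels = ["a", "a", "b", "a", "c"]

fig, ax = plt.subplots()
sc = ax.scatter(x_values, y_values, c=colours)

# legend_elements() returns one handle per unique colour value,
# in ascending order of value (here -1, 0, 1).
handles, _ = sc.legend_elements()

# Assumption: each colour value maps to exactly one label.
label_for = dict(zip(colours, labels))
names = [label_for[value] for value in sorted(set(colours))]

legend = ax.legend(handles, names)
```

Call `plt.show()` afterwards to display the figure. The legend entries come out ordered by colour value rather than alphabetically, so here the order is "c", "b", "a".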

Upvotes: 1

ImportanceOfBeingErnest

Reputation: 339660

If you want to use a colormap you can create a legend entry for each unique entry in the colors list as shown below. This approach works well for any number of values. The legend handles are the markers of a plot, such that they match with the scatter points.

import matplotlib.pyplot as plt

x_values = [6.2, 3.6, 7.3, 3.2, 2.7]
y_values = [1.5, 3.2, 5.4, 3.1, 2.8]
colors = [1, 1, 0, 1, -1]
labels = ["a", "a", "b", "a", "c"]
clset = set(zip(colors, labels))

ax = plt.gca()
sc = ax.scatter(x_values, y_values, c=colors, cmap="brg")

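# One empty line artist per (colour value, label) pair: it draws nothing,
# but provides a marker in the matching colormap colour as a legend handle.
handles = [plt.plot([], [], color=sc.get_cmap()(sc.norm(c)), ls="", marker="o")[0]
           for c, l in clset]
labels = [l for c, l in clset]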
ax.legend(handles, labels)

plt.show()


Upvotes: 3

Martin Evans

Reputation: 46779

You can always make your own legend as follows:

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

x_values = [6.2, 3.6, 7.3, 3.2, 2.7]
y_values = [1.5, 3.2, 5.4, 3.1, 2.8]

a = 'red'
b = 'blue'
c = 'yellow'

colours = [a, a, b, a, c]
labels = ["a", "a", "b", "a", "c"]

axis = plt.gca()
axis.scatter(x_values, y_values, c=colours)

# Create a legend
handles = [mpatches.Patch(color=colour, label=label) for label, colour in [('a', a), ('b', b), ('c', c)]]
plt.legend(handles=handles, loc=2, frameon=True)

plt.show()

Which would look like:

[Figure: plot with legend]
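As for running scatter separately for each category: that also works, and it lets matplotlib build the legend from the `label` argument. Below is a minimal sketch under the same colour choices as above. The `colour_for` mapping and the loop are illustrative assumptions, not the only way to group the points.

```python
import matplotlib.pyplot as plt

x_values = [6.2, 3.6, 7.3, 3.2, 2.7]
y_values = [1.5, 3.2, 5.4, 3.1, 2.8]
labels = ["a", "a", "b", "a", "c"]

# Assumed colour for each category, mirroring the example above.
colour_for = {"a": "red", "b": "blue", "c": "yellow"}

fig, ax = plt.subplots()
for category in sorted(set(labels)):
    # Pick out the points that belong to this category.
    xs = [x for x, l in zip(x_values, labels) if l == category]
    ys = [y for y, l in zip(y_values, labels) if l == category]
    ax.scatter(xs, ys, color=colour_for[category], label=category)

# Each scatter call contributes one legend entry via its label.
legend = ax.legend(loc="upper left", frameon=True)
```

With this version the legend handles are the scatter markers themselves rather than rectangular patches. Call `plt.show()` to display the figure.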

Upvotes: 0
