Alex W.

Customize x and y labels in matplotlib scatter plot

I have two lists xs and ys of equal length that I use to draw a scatter plot:

import random
import matplotlib.pyplot as plt

xs = [random.randrange(0,100) for i in range(50)]
ys = [random.randrange(0,100) for i in range(50)]

plt.scatter(xs,ys)

However, I don't want the standard tick labels; instead, I want labels taken from, e.g., the following dictionaries:

x_labels = {40: "First", 52: "Second", 73: "Third", 99: "Fourth"}
y_labels = {10: "FIRST", 80: "SECOND"}

So what I'm trying to do is draw a scatter plot with the label "First" at x = 40, "Second" at x = 52 and so on, as well as "FIRST" at y = 10 and "SECOND" at y = 80. Unfortunately, I haven't found a way to achieve this.

Thanks a lot!

Answers (1)

JohanC

To display the tick labels at the desired positions, you can use:

plt.xticks(list(x_labels.keys()), list(x_labels.values()))
plt.yticks(list(y_labels.keys()), list(y_labels.values()))
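The same fixed labels can also be set through the object-oriented API (ax.set_xticks() followed by ax.set_xticklabels()). The sketch below is a minimal version, assuming the random data and the label dictionaries from the question; the Agg backend is only there so it runs without a display, and the label text is read back after a draw to check that each label ended up at its dictionary key.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import random
import matplotlib.pyplot as plt

xs = [random.randrange(0, 100) for _ in range(50)]
ys = [random.randrange(0, 100) for _ in range(50)]
x_labels = {40: "First", 52: "Second", 73: "Third", 99: "Fourth"}
y_labels = {10: "FIRST", 80: "SECOND"}

fig, ax = plt.subplots()
ax.scatter(xs, ys)

# Positions first, then one label per position, in the same order
ax.set_xticks(list(x_labels.keys()))
ax.set_xticklabels(list(x_labels.values()))
ax.set_yticks(list(y_labels.keys()))
ax.set_yticklabels(list(y_labels.values()))

fig.canvas.draw()  # make sure the tick label text has been generated
x_text = [t.get_text() for t in ax.get_xticklabels()]
y_text = [t.get_text() for t in ax.get_yticklabels()]
print(x_text, y_text)
```

Setting the ticks also widens the view limits if needed, so all labelled positions stay visible whatever the random data is.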

Setting the labels through plt.xticks() and plt.yticks() has a side effect, as you noted: the cursor coordinates are no longer displayed in the status bar.
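As far as I can tell, the reason is that passing labels replaces the axis's major formatter with one that only knows the fixed labels, and the status bar formats the cursor position with that same formatter (via Axes.format_xdata, which the coordinate readout uses for the x value). A small check, assuming a headless Agg backend and a few arbitrary sample points:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no window is opened
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.scatter([10, 60, 90], [20, 50, 70])  # arbitrary sample data
ax.set_xticks([40, 52, 73, 99])
ax.set_xticklabels(["First", "Second", "Third", "Fourth"])

# The status bar formats x through Axes.format_xdata, which asks the
# major formatter for a string; a fixed-label formatter has nothing
# to return for a position that is not one of its ticks.
readout = ax.format_xdata(61.3)
print(repr(readout))
```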

A workaround that keeps both the custom tick labels and the coordinate display is a custom tick formatter. A ticker.FuncFormatter calls its function with two arguments: the value x and a position pos. When the coordinates are formatted for the status bar, pos is None; when tick labels are formatted, pos is the index of the tick. So, by checking whether pos is None, the formatter can return the desired label for a tick and the number formatted as a string otherwise. The tick positions still need to be set via plt.xticks(), but without the labels.
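The pos distinction can be seen in isolation by calling a FuncFormatter directly. In this sketch, describe, labels and tick_positions are illustrative names; fmt(40, 0) mimics how a tick label is requested, and format_data_short mimics the status-bar request (in the Formatter base class it ends up calling the function with pos=None).

```python
from matplotlib import ticker

labels = {40: "First", 52: "Second"}  # illustrative subset of x_labels
tick_positions = list(labels)          # tick order, as passed to xticks()

def describe(x, pos):
    if pos is not None:
        # Called for tick number `pos`: return the custom label
        return labels[tick_positions[pos]]
    # pos is None: a plain number for the coordinate readout
    return f"{x:.2f}"

fmt = ticker.FuncFormatter(describe)
tick_text = fmt(40, 0)                      # how a tick label is requested
status_text = fmt.format_data_short(61.25)  # how the status bar asks for a value
print(tick_text, status_text)
```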

Here is an example:

import random
import matplotlib.pyplot as plt
from matplotlib import ticker

# Tick positions (keys) and the text to show there (values)
x_labels = {40: "First", 52: "Second", 73: "Third", 99: "Fourth"}
x_labels_list = list(x_labels.values())
y_labels = {10: "FIRST", 80: "SECOND"}
y_labels_list = list(y_labels.values())

@ticker.FuncFormatter
def major_x_formatter(x, pos):
    if pos is not None:
        # Tick label: pos is the tick's index, in the order passed to plt.xticks()
        return x_labels_list[pos]
    # pos is None: format the value for the status bar
    x_r = int(round(x))
    if x_r in x_labels:
        return f"{x:.0f}:{x_labels[x_r]}"
    else:
        return f"{x:.2f}"

@ticker.FuncFormatter
def major_y_formatter(y, pos):
    if pos is not None:
        # Tick label: pos is the tick's index, in the order passed to plt.yticks()
        return y_labels_list[pos]
    # pos is None: format the value for the status bar
    y_r = int(round(y))
    if y_r in y_labels:
        return f"{y:.0f}:{y_labels[y_r]}"
    else:
        return f"{y:.2f}"

xs = [random.randrange(0, 100) for _ in range(50)]
ys = [random.randrange(0, 100) for _ in range(50)]

plt.scatter(xs, ys)

# Set only the tick positions; the formatters supply the text
plt.xticks(list(x_labels.keys()))
plt.yticks(list(y_labels.keys()))
plt.gca().xaxis.set_major_formatter(major_x_formatter)
plt.gca().yaxis.set_major_formatter(major_y_formatter)

plt.show()

[Figure: resulting plot]
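If the two nearly identical formatters feel repetitive, a single factory can build both. make_label_formatter below is a hypothetical helper, not part of matplotlib; it also looks the tick label up by the rounded tick value instead of by pos, so it does not depend on the order of the list given to the tick positions. A sketch, assuming integer tick positions as in the question and a headless Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import random
import matplotlib.pyplot as plt
from matplotlib import ticker

def make_label_formatter(labels):
    """Hypothetical helper: build a FuncFormatter for a {position: label} dict."""
    def fmt(value, pos):
        key = int(round(value))  # assumes integer tick positions
        if pos is not None:
            # Tick label: look it up by the tick's own value
            return labels.get(key, "")
        # Status bar: the number, plus the label near a labelled position
        if key in labels:
            return f"{value:.0f}:{labels[key]}"
        return f"{value:.2f}"
    return ticker.FuncFormatter(fmt)

x_labels = {40: "First", 52: "Second", 73: "Third", 99: "Fourth"}
y_labels = {10: "FIRST", 80: "SECOND"}
xs = [random.randrange(0, 100) for _ in range(50)]
ys = [random.randrange(0, 100) for _ in range(50)]

fig, ax = plt.subplots()
ax.scatter(xs, ys)
ax.set_xticks(list(x_labels))  # positions only
ax.set_yticks(list(y_labels))
ax.xaxis.set_major_formatter(make_label_formatter(x_labels))
ax.yaxis.set_major_formatter(make_label_formatter(y_labels))
fig.canvas.draw()  # generate the tick label text
```

Rounding to an int only works for integer tick positions; fractional positions would need a tolerance-based lookup instead.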