Reputation: 2454
I would like to plot some values and highlight values outside a certain range in specific colors. To do that, I do the following:
- I use set_under and set_over to specify which colors to use for the out-of-range values.
- I pass vmin and vmax when plotting the data.

This works fine: data outside the accepted range is marked with the correct color.
Now I would like to add text to the plot showing the value at each point. Since the background colors differ, I need different text colors depending on the value. To do that, I read the RGB values from the colormap, calculate the luminance of the color, and choose a white or black text color accordingly. The issue is that the colormap values I get seem to be incorrect. Here is some example code:
from copy import copy

import numpy as np
from matplotlib import pyplot as plt

min_data = -1.1
max_data = 3.5
test_array = np.random.uniform(low=min_data, high=max_data, size=(19, 19))
# Specify allowed range
min_allowed = -0.9
max_allowed = 2.9
fig, ax = plt.subplots(1, 1)
palette = copy(plt.get_cmap("gray"))
palette.set_over('yellow', 1.0)
palette.set_under('blue', 1.0)
im = ax.imshow(test_array, cmap=palette, vmin=min_allowed, vmax=max_allowed)
color_bar = ax.figure.colorbar(im, ax=ax)
for i in range(test_array.shape[0]):
    for j in range(test_array.shape[1]):
        data_val = (test_array[i, j] * 255).astype(np.uint8)
        cmap_value = im.cmap(data_val, bytes=True)
        luminance = (0.299 * cmap_value[0] + 0.587 * cmap_value[1] + 0.114 * cmap_value[2]) / 255
        if luminance > 0.45:
            color = "black"
        else:
            color = "white"
        text = im.axes.text(j, i, "{:.2f}".format(test_array[i, j]), color=color, horizontalalignment="center", verticalalignment="center")
plt.show()
The result is the following:
Some values above the maximum allowed range are correctly displayed in yellow, but their text is white (and thus unreadable). If I check the colormap value I get for, e.g., a value of 3.055 (i.e. outside the range), cmap_value is (11, 11, 11, 255), whereas I would expect the yellow RGB values, i.e. (255, 255, 0, 255).
I am a bit lost as to what is going wrong. Any tips?
Upvotes: 1
Views: 899
Reputation: 80329
First, note that matplotlib's colormaps work in two different ways:
- When called with an integer (including NumPy integer types such as np.uint8), the argument is used as an index into the colormap's list of colors. When the argument is lower than 0, the under color is returned. When the argument is greater than or equal to the number of colors, the over color is returned. Otherwise, the color at the given index is returned.
- When called with a float, 0.0 is considered the lowest color in the array and 1.0 the highest color. Inside that range, the value is mapped to the corresponding color. Outside the range, either the under or the over color is returned. For NaN values the bad color is returned (typically fully transparent, i.e. 'none').

When called via a function such as imshow, the given value is first normalized such that min_allowed corresponds to 0.0 and max_allowed to 1.0.
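A minimal sketch of the two calling conventions described above, assuming a reasonably recent matplotlib (it relies on plt.get_cmap and Colormap.copy()). It calls a copy of the gray colormap directly, once with integers and once with floats:

```python
from matplotlib import pyplot as plt

# Work on a copy so other plots using "gray" are unaffected
palette = plt.get_cmap("gray").copy()
palette.set_over("yellow")
palette.set_under("blue")

# Integer arguments are indices into the N-entry color table
print(palette.N)                        # 256
print(palette(-1, bytes=True))          # below 0 -> under color (blue)
print(palette(palette.N, bytes=True))   # >= N -> over color (yellow)
print(palette(11, bytes=True))          # an ordinary dark gray from the table

# Float arguments: 0.0 .. 1.0 spans the whole map, anything outside is under/over
print(palette(-0.2, bytes=True))        # under color (blue)
print(palette(1.2, bytes=True))         # over color (yellow)
print(palette(0.5, bytes=True))         # a mid gray
```

Since an np.uint8 can only hold 0..255 and the gray map has N = 256 colors, an unsigned 8-bit index can never select the under or over color.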
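For convenience, matplotlib provides the Normalize class (also available as plt.Normalize), which performs exactly this mapping.

So, instead of data_val = (test_array[i, j] * 255).astype(np.uint8), you could use data_val = norm(test_array[i, j]) with norm = plt.Normalize(min_allowed, max_allowed). Your version turns the value into an 8-bit unsigned integer, which the colormap treats as an index in 0..255; out-of-range values end up somewhere inside that range (for 3.055 you got index 11, hence (11, 11, 11, 255)), so the under and over colors are never selected.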
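A short sketch of the Normalize approach applied to the value from the question. The limits are copied from the question; the numbers in the comments are approximate:

```python
from matplotlib import pyplot as plt

min_allowed, max_allowed = -0.9, 2.9
# Linear map: min_allowed -> 0.0, max_allowed -> 1.0 (no clipping by default)
norm = plt.Normalize(min_allowed, max_allowed)

palette = plt.get_cmap("gray").copy()
palette.set_over("yellow")
palette.set_under("blue")

value = 3.055  # the out-of-range value from the question
print(norm(value))                        # (3.055 + 0.9) / 3.8, about 1.041
print(palette(norm(value), bytes=True))   # above 1.0 -> over color (yellow)
print(palette(norm(-1.0), bytes=True))    # below 0.0 -> under color (blue)
print(palette(norm(1.0), bytes=True))     # inside the range -> a gray
```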
Instead of computing the luminance by hand, you could use matplotlib.colors.rgb_to_hsv; the third value of the result is the HSV "value" (the maximum of R, G and B), which can serve as a simple brightness measure.
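To illustrate what the third HSV component measures, the following sketch compares it with the weighted luminance used in the question (the 0.299/0.587/0.114 weights, i.e. the Rec. 601 luma coefficients). The helper names hsv_value and weighted_luminance are hypothetical, introduced only for this comparison:

```python
import numpy as np
from matplotlib.colors import rgb_to_hsv, to_rgb


def hsv_value(rgb):
    """Third HSV component, i.e. max(R, G, B), for RGB floats in 0..1."""
    return rgb_to_hsv(np.asarray(rgb, dtype=float))[2]


def weighted_luminance(rgb):
    """Weighted sum with the Rec. 601 luma weights, as in the question."""
    r, g, b = rgb
    return 0.299 * r + 0.587 * g + 0.114 * b


for name in ["yellow", "blue", "white", "black"]:
    rgb = to_rgb(name)
    print(f"{name:>6}: value={hsv_value(rgb):.3f} luminance={weighted_luminance(rgb):.3f}")
```

With the HSV value, pure blue and pure yellow both come out as 1.0, so a 0.45 threshold gives black text on both; the weighted luminance separates them (about 0.11 vs 0.89), which may make it the better choice when picking text colors.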
Here is how the adapted code could look:
from matplotlib import pyplot as plt
from matplotlib.colors import rgb_to_hsv
import numpy as np

min_data = -1.1
max_data = 3.5
test_array = np.random.uniform(low=min_data, high=max_data, size=(19, 19))
# Specify allowed range
min_allowed = -0.9
max_allowed = 2.9
# Maps min_allowed -> 0.0 and max_allowed -> 1.0; values outside fall below 0 or above 1
norm = plt.Normalize(min_allowed, max_allowed)

fig, ax = plt.subplots(figsize=(15, 12))
palette = plt.get_cmap("gray").copy()
palette.set_over('yellow', 1.0)
palette.set_under('blue', 1.0)
im = ax.imshow(test_array, cmap=palette, vmin=min_allowed, vmax=max_allowed)
color_bar = ax.figure.colorbar(im, ax=ax)
for i in range(test_array.shape[0]):
    for j in range(test_array.shape[1]):
        # Pass a normalized float, so out-of-range values get the over/under colors
        data_val = norm(test_array[i, j])
        cmap_value = im.cmap(data_val, bytes=True)
        # Third HSV component = max(R, G, B); bytes=True gives 0..255, hence / 255
        luminance = rgb_to_hsv(cmap_value[:3])[2] / 255
        text = im.axes.text(j, i, f"{test_array[i, j]:.2f}", color='black' if luminance > 0.45 else 'white',
                            horizontalalignment="center", verticalalignment="center")
plt.show()
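A vectorized variant of the same idea, offered as a sketch rather than a drop-in replacement: instead of normalizing cell by cell, it reuses im.norm (the Normalize instance that imshow builds from vmin and vmax) and computes all RGBA values in one call. It swaps in the weighted luminance from the question instead of the HSV value, and seeds the random generator only so the result is reproducible:

```python
import numpy as np
from matplotlib import pyplot as plt

# Seeded generator (an assumption, for reproducibility only)
rng = np.random.default_rng(0)
test_array = rng.uniform(low=-1.1, high=3.5, size=(19, 19))
min_allowed, max_allowed = -0.9, 2.9

palette = plt.get_cmap("gray").copy()
palette.set_over("yellow")
palette.set_under("blue")

fig, ax = plt.subplots(figsize=(15, 12))
im = ax.imshow(test_array, cmap=palette, vmin=min_allowed, vmax=max_allowed)
fig.colorbar(im, ax=ax)

# im.norm maps min_allowed -> 0.0 and max_allowed -> 1.0 without clipping,
# so the colormap returns the over/under colors for out-of-range cells
rgba = im.cmap(im.norm(test_array))  # shape (19, 19, 4), floats in 0..1
# Weighted luminance per cell, using the weights from the question
luminance = rgba[..., :3] @ np.array([0.299, 0.587, 0.114])
text_colors = np.where(luminance > 0.45, "black", "white")

for (i, j), value in np.ndenumerate(test_array):
    ax.text(j, i, f"{value:.2f}", color=text_colors[i, j],
            horizontalalignment="center", verticalalignment="center")
plt.show()
```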
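The seaborn library has a handy function, sns.heatmap, that does all this (and more) in one go; with annot=True it also picks a dark or a white annotation color depending on the luminance of each cell:
import seaborn as sns

fig, ax = plt.subplots(figsize=(15, 12))
sns.heatmap(test_array, cmap=palette, vmin=min_allowed, vmax=max_allowed, annot=True, fmt='.2f', cbar=True, ax=ax)
plt.show()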
Upvotes: 2