Sahil

Reputation: 276

How to control the color of pixels in matplotlib.pyplot.imshow?

I want to represent a 2D matrix (array) using matplotlib.pyplot.imshow(). It works fine, but I want to control the color of each pixel myself instead of letting the function choose. For example, I generate an array like this:

import numpy as np

N = 5  # matrix size (example value)
lat = np.empty((N, N), dtype=int)
for i in range(N):
    for j in range(N):
        x = np.random.random()
        if x <= 0.4:
            lat[i, j] = 0
        elif 0.4 < x <= 0.5:
            lat[i, j] = 1
        elif 0.5 < x <= 0.6:
            lat[i, j] = 2
        else:
            lat[i, j] = 3

This generates the matrix I want. Now, when calling plt.imshow(), I want each element to be drawn in a specific color depending on its value (here 0, 1, 2 or 3). How can I do that?

Upvotes: 2

Views: 476

Answers (1)

JohanC

Reputation: 80309

You could create a LinearSegmentedColormap with the desired colors:

from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import numpy as np

N = 5
lat = np.empty((N, N), dtype=int)  # np.int was deprecated and removed in NumPy 1.24
for i in range(N):
    for j in range(N):
        x = np.random.random()
        if x <= 0.4:
            lat[i, j] = 0
        elif 0.4 < x <= 0.5:
            lat[i, j] = 1
        elif 0.5 < x <= 0.6:
            lat[i, j] = 2
        else:
            lat[i, j] = 3

my_colors = ['crimson', 'lime', 'dodgerblue', 'gold'] # colors for 0, 1, 2 and 3
cmap = LinearSegmentedColormap.from_list('', my_colors, len(my_colors))
plt.imshow(lat, cmap=cmap, vmin=0, vmax=len(my_colors) - 1, alpha=0.4)
for i in range(lat.shape[0]):
    for j in range(lat.shape[1]):
        plt.text(j, i, lat[i, j])
plt.show()

[Image: example plot]
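A sketch of an alternative, not part of the original answer: a ListedColormap combined with a BoundaryNorm whose bin edges sit halfway between the integers, so each value 0 to 3 is looked up in its own bin instead of relying on vmin/vmax. The random `lat` below is only a stand-in for the matrix from the question, and the colorbar with one tick per value is an optional extra.

```python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Stand-in data: random integer codes 0..3, playing the role of `lat`
rng = np.random.default_rng(0)
lat = rng.integers(0, 4, size=(5, 5))

my_colors = ['crimson', 'lime', 'dodgerblue', 'gold']  # colors for 0, 1, 2 and 3
cmap = mcolors.ListedColormap(my_colors)
# Bin edges -0.5, 0.5, 1.5, 2.5, 3.5: each integer falls into its own bin
norm = mcolors.BoundaryNorm(np.arange(len(my_colors) + 1) - 0.5, cmap.N)

fig, ax = plt.subplots()
im = ax.imshow(lat, cmap=cmap, norm=norm)
# Write each value at the center of its pixel
for i in range(lat.shape[0]):
    for j in range(lat.shape[1]):
        ax.text(j, i, lat[i, j], ha='center', va='center')
# Optional colorbar with one tick per value
fig.colorbar(im, ax=ax, ticks=range(len(my_colors)))
plt.show()
```

Since the norm handles the value-to-color mapping, `vmin`/`vmax` are not passed here; recent Matplotlib versions reject combining them with `norm=`.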

PS: NumPy also has a function, np.digitize, that automates the first step of the code:

x = np.random.random((N, N))
lat = np.digitize(x, [0.4, 0.5, 0.6], right=True)

my_colors = ['fuchsia', 'lime', 'turquoise', 'gold']
cmap = LinearSegmentedColormap.from_list('', my_colors, len(my_colors))
plt.imshow(lat, cmap=cmap, vmin=0, vmax=len(my_colors) - 1)
for i in range(lat.shape[0]):
    for j in range(lat.shape[1]):
        plt.text(j, i, f'{x[i, j]:.2f}\n{lat[i, j]}', ha='center', va='center')
plt.show()

[Image: second example]
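A quick check that `right=True` reproduces the question's if/elif chain: with `right=True`, a value lying exactly on an edge such as 0.4 goes into the lower bin, matching `x <= 0.4`, while the default `right=False` would put it in the next bin. The helper `classify` is hypothetical, written only to mirror the loop from the question.

```python
import numpy as np

bins = [0.4, 0.5, 0.6]

def classify(x):
    """Hypothetical helper mirroring the if/elif chain from the question."""
    if x <= 0.4:
        return 0
    elif x <= 0.5:
        return 1
    elif x <= 0.6:
        return 2
    return 3

# Values on and around the bin edges
samples = np.array([0.1, 0.4, 0.45, 0.5, 0.55, 0.6, 0.9])
print(np.digitize(samples, bins, right=True))  # [0 0 1 1 2 2 3]
print(np.digitize(samples, bins))              # default right=False: [0 1 1 2 2 3 3]

# Compare against the loop version on a random matrix
x = np.random.default_rng(1).random((50, 50))
lat_loop = np.array([[classify(v) for v in row] for row in x])
lat_digitize = np.digitize(x, bins, right=True)
print(np.array_equal(lat_loop, lat_digitize))  # True
```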

Upvotes: 3
