I have multiple images (NumPy arrays) whose data values correspond to N different classes. Each image does not necessarily contain examples of every class. For example, there might be 12 classes in total (0 to 11), while one image only contains classes 1 to 9.
I would like to plot each image such that the color assigned to each class is the same across all images.
I've looked into several existing answers: in one, neither the accepted nor the popular answers worked across multiple images; another seems like it could work, but I would really like to use a colormap (from matplotlib import cm) so that I don't have to set colors manually. I would also like a way to create a colorbar that contains all the classes.
The code I've tried is below:
import numpy as np
import matplotlib.pyplot as plt
t1 = np.arange(9).reshape(3,3)
t2 = t1.copy()
t2[1,1] = 10
t3 = t2.copy()
t3[1,1] = 11
# cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap(name, lut) still exists
cmap = plt.get_cmap('tab20', 11)
fig, axs = plt.subplots(1,3)
axs[0].imshow(t1, cmap = cmap, vmin = 0, vmax = 11)
axs[1].imshow(t2, cmap = cmap, vmin = 0, vmax = 11)
axs[2].imshow(t3, cmap = cmap, vmin = 0, vmax = 11)
plt.show()
For future reference, in case you want to define your own colors rather than use a predefined cmap, here is some code I wrote specifically for this some time ago.
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# Example images from the question
t1 = np.arange(9).reshape(3, 3)
t2 = t1.copy()
t2[1, 1] = 10
t3 = t2.copy()
t3[1, 1] = 11

C_p = 12  # Total number of classes (0..11) across all images
colour_names = [  # Your predefined colours, one per class
    "blue",
    "red",
    "yellow",
    "orange",
    "black",
    "purple",
    "green",
    "turquoise",
    "grey",
    "maroon",
    "silver",
    "white"
]
colour_dict = {  # Colour mapping (class -> RGB)
    i: mpl.colors.to_rgb(colour_names[i])
    for i in range(C_p)
}
# Create a colormap with exactly one colour per class
colours_rgb = [colour_dict[i] for i in range(C_p)]
colours = mpl.colors.ListedColormap(colours_rgb)
# One bin per integer class: boundaries at -0.5, 0.5, ..., C_p - 0.5
norm = mpl.colors.BoundaryNorm(np.arange(C_p + 1) - 0.5, C_p)

plt.figure()  # If you only want to plot one image
plt.imshow(t2, cmap=colours, norm=norm)
cb = plt.colorbar(ticks=np.arange(C_p))
plt.axis("off")
Example with your t1, t2 and t3:
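fig, axs = plt.subplots(1, 3)
axs[0].imshow(t1, cmap=colours, norm=norm)
axs[0].set_title("t1")
axs[0].axis('off')
axs[1].imshow(t2, cmap=colours, norm=norm)
axs[1].set_title("t2")
axs[1].axis('off')
im = axs[2].imshow(t3, cmap=colours, norm=norm)
axs[2].set_title("t3")
axs[2].axis('off')
# Lay out the subplots first, leaving the bottom 15% free for the colorbar;
# tight_layout does not handle axes added manually with fig.add_axes
fig.tight_layout(rect=(0, 0.15, 1, 1))
# Subplot positions as [x0, y0, x1, y1] in figure coordinates
p0 = axs[0].get_position().get_points().flatten()
p2 = axs[2].get_position().get_points().flatten()
# add_axes takes [left, bottom, width, height]: span from the left edge
# of the first subplot to the right edge of the last one
ax_cbar = fig.add_axes([p0[0], 0.08, p2[2] - p0[0], 0.05])
plt.colorbar(im, cax=ax_cbar, ticks=np.arange(C_p), orientation='horizontal')
plt.show()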
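Since the question asks for a predefined colormap rather than hand-picked colours, the same BoundaryNorm idea can be combined with a colormap sampled down to one colour per class. A minimal sketch, assuming Matplotlib 3.5 or later (for layout='constrained'); n_classes is an illustrative name:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm

# Example images from the question: classes 0..8, plus 10 and 11
t1 = np.arange(9).reshape(3, 3)
t2 = t1.copy()
t2[1, 1] = 10
t3 = t2.copy()
t3[1, 1] = 11

n_classes = 12  # total number of classes across ALL images (0..11)

# Sample one colour per class from a predefined colormap
cmap = plt.get_cmap('tab20', n_classes)
# One bin per integer class, so class k always uses colour k
norm = BoundaryNorm(np.arange(n_classes + 1) - 0.5, cmap.N)

fig, axs = plt.subplots(1, 3, layout='constrained')
for ax, img, title in zip(axs, (t1, t2, t3), ("t1", "t2", "t3")):
    im = ax.imshow(img, cmap=cmap, norm=norm)
    ax.set_title(title)
    ax.axis('off')

# A single colorbar shared by all three axes, with one tick per class
fig.colorbar(im, ax=axs, ticks=np.arange(n_classes), orientation='horizontal')
plt.show()
```

Because every image shares the same cmap and norm, a class keeps its colour regardless of which other classes are present, and passing ax=axs to fig.colorbar lets Matplotlib place one colorbar across all three panels without computing positions by hand.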
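It looks like the colormap passed to get_cmap needs one color for every possible class across all images (12 here, for classes 0 to 11), not 11; with only 11 colors, at least two classes are forced to share a color. The code below works: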
import numpy as np
import matplotlib.pyplot as plt
t1 = np.arange(9).reshape(3,3)
t2 = t1.copy()
t2[1,1] = 10
t3 = t2.copy()
t3[1,1] = 11
# One colour per possible class (0..11) across all images.
# cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap(name, lut) still exists
cmap = plt.get_cmap('tab20', 12)
fig, axs = plt.subplots(1,3)
axs[0].imshow(t1, cmap = cmap, vmin = 0, vmax = 11)
axs[1].imshow(t2, cmap = cmap, vmin = 0, vmax = 11)
axs[2].imshow(t3, cmap = cmap, vmin = 0, vmax = 11)
plt.show()
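A quick way to see why the number of colors matters is to count how many distinct colors the 12 classes actually receive under the same vmin=0, vmax=11 scaling that imshow uses. A minimal sketch, using plt.get_cmap because cm.get_cmap is gone from recent Matplotlib releases; the dictionary name distinct is only illustrative:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

# Same scaling as imshow(..., vmin=0, vmax=11)
norm = Normalize(vmin=0, vmax=11)
classes = np.arange(12)  # every class that can appear in any image

distinct = {}
for n_colours in (11, 12):
    cmap = plt.get_cmap('tab20', n_colours)
    rgba = cmap(norm(classes))  # one RGBA row per class
    distinct[n_colours] = len({tuple(row) for row in rgba})
    print(f"{n_colours} colours: {distinct[n_colours]} distinct class colours")
```

With 11 colors there are necessarily fewer distinct colors than classes, whereas with 12 every class gets its own color.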