AI_NA

Reputation: 346

How to change the color of boxes in a confusion matrix using sklearn?

Here is my code snippet that produces a confusion matrix. How can I change the color of the boxes that are not on the diagonal, the way a heatmap does, using sklearn?

import matplotlib.pyplot as plt
import numpy
import torch

# model, device and dataLoaders are defined elsewhere in the script
nb_classes = 15
confusion_matrix = torch.zeros(nb_classes, nb_classes)

with torch.no_grad():
    for i, (inputs, target, classes, im_path) in enumerate(dataLoaders['test']):
        inputs = inputs.to(device)
        target = target.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        # count (true, predicted) pairs; move to CPU because confusion_matrix lives there
        for t, p in zip(target.view(-1).cpu(), preds.view(-1).cpu()):
            confusion_matrix[t.long(), p.long()] += 1

class_names = ['A2CH', 'A3CH', 'A4CH_LV', 'A4CH_RV', 'A5CH', 'Apical_MV_LA_IAS',
               'OTHER', 'PLAX_TV', 'PLAX_full', 'PLAX_valves', 'PSAX_AV', 'PSAX_LV',
               'Subcostal_IVC', 'Subcostal_heart', 'Suprasternal']

cm = confusion_matrix.numpy()
plt.figure()
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.colorbar()

# label both axes with the class names
tick_marks = numpy.arange(nb_classes)
plt.xticks(tick_marks, class_names, rotation=90)
plt.yticks(tick_marks, class_names)

thresh = cm.max() / 2.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        # zeros are written in white (so they vanish on the light background),
        # dark cells get white text, everything else black
        plt.text(j, i, format(cm[i, j], 'g'),
                 ha="center", va="center",
                 color="white" if cm[i, j] == 0 or cm[i, j] > thresh else "black")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
plt.show()
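A single cmap colors cells by value, not by position, so it cannot by itself give the off-diagonal boxes a different color. A minimal sketch using only numpy and matplotlib, assuming the matrix can be converted to a NumPy array: draw it twice with imshow, masking the off-diagonal cells in the first layer and the diagonal cells in the second. The function name plot_cm_two_colormaps and its diag_cmap/off_cmap parameters are hypothetical, chosen for this example.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch; remove it to use plt.show()
import matplotlib.pyplot as plt
import numpy as np


def plot_cm_two_colormaps(cm, class_names, diag_cmap="Blues", off_cmap="Reds"):
    """Draw the diagonal of a confusion matrix with one colormap and the rest with another."""
    cm = np.asarray(cm, dtype=float)  # also accepts a nested list or a CPU torch tensor
    n = cm.shape[0]
    on_diag = np.eye(n, dtype=bool)

    fig, ax = plt.subplots(figsize=(8, 8))
    # Masked cells use the colormap's "bad" color, which is transparent by default,
    # so each layer only paints its own cells. Each layer is normalized separately,
    # so small off-diagonal counts still span the whole off_cmap range.
    ax.imshow(np.ma.masked_where(~on_diag, cm), cmap=diag_cmap, interpolation="nearest")
    ax.imshow(np.ma.masked_where(on_diag, cm), cmap=off_cmap, interpolation="nearest")

    # write the count in every cell
    for i in range(n):
        for j in range(n):
            ax.text(j, i, f"{cm[i, j]:g}", ha="center", va="center")

    ticks = np.arange(n)
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_names, rotation=90)
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_names)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    fig.tight_layout()
    return fig, ax


# Hypothetical 3-class example; with the question's data this would be
# plot_cm_two_colormaps(confusion_matrix, class_names).
cm = np.array([[5, 1, 0],
               [2, 7, 1],
               [0, 0, 9]])
fig, ax = plot_cm_two_colormaps(cm, ["A", "B", "C"])
```

Because the two layers are normalized independently, the diagonal and off-diagonal colors are not comparable with each other; if that matters, pass the same vmin/vmax to both imshow calls.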


Upvotes: 5

Views: 45776

Answers (2)

Atif Rizwan

Reputation: 685

Use a seaborn heatmap to plot the confusion matrix:

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

array = [[33,2,0,0,0,0,0,0,0,1,3],
         [3,31,0,0,0,0,0,0,0,0,0],
         [0,4,41,0,0,0,0,0,0,0,1],
         [0,1,0,30,0,6,0,0,0,0,1],
         [0,0,0,0,38,10,0,0,0,0,0],
         [0,0,0,3,1,39,0,0,0,0,4],
         [0,2,2,0,4,1,31,0,0,0,2],
         [0,1,0,0,0,0,0,36,0,2,0],
         [0,0,0,0,0,0,1,5,37,5,1],
         [3,0,0,0,0,0,0,0,0,39,0],
         [0,0,0,0,0,0,0,0,0,0,0]]  # last row padded to 11 entries so every row matches the 11 labels
df_cm = pd.DataFrame(array, index=list("ABCDEFGHIJK"),
                     columns=list("ABCDEFGHIJK"))
plt.figure(figsize=(10, 7))
sns.heatmap(df_cm, annot=True, cmap="OrRd")  # cmap selects the color scheme
plt.show()

heatmap accepts a cmap argument that sets the colormap of the matrix. Some possible values for cmap:

cmap = [Accent, Accent_r, Blues, Blues_r, BrBG, BrBG_r, BuGn, BuGn_r, BuPu, 
BuPu_r, CMRmap, CMRmap_r, Dark2, Dark2_r, GnBu, GnBu_r, Greens, Greens_r, 
Greys, Greys_r, OrRd, OrRd_r, Oranges, Oranges_r, PRGn, PRGn_r, Paired, Paired_r, 
Pastel1, Pastel1_r, Pastel2, Pastel2_r, PiYG, PiYG_r, PuBu, PuBuGn, PuBuGn_r, 
PuBu_r, PuOr, PuOr_r, PuRd, PuRd_r, Purples, Purples_r, RdBu, RdBu_r, RdGy, RdGy_r, 
RdPu, RdPu_r, RdYlBu, RdYlBu_r, RdYlGn, RdYlGn_r, Reds, Reds_r, Set1, Set1_r, 
Set2, Set2_r, Set3, Set3_r, Spectral, Spectral_r, Wistia, Wistia_r, YlGn, YlGnBu, 
YlGnBu_r, YlGn_r, YlOrBr, YlOrBr_r, YlOrRd, YlOrRd_r, afmhot, afmhot_r, autumn,
autumn_r, binary, binary_r, bone, bone_r, brg, brg_r, bwr, bwr_r, cividis, 
cividis_r, cool, cool_r, coolwarm, coolwarm_r, copper, copper_r, cubehelix, 
cubehelix_r, flag, flag_r, gist_earth, gist_earth_r, gist_gray, gist_gray_r,
gist_heat, gist_heat_r, gist_ncar, gist_ncar_r, gist_rainbow, gist_rainbow_r, 
gist_stern, gist_stern_r, gist_yarg, gist_yarg_r, gnuplot, gnuplot2, gnuplot2_r, 
gnuplot_r, gray, gray_r, hot, hot_r, hsv, hsv_r, icefire, icefire_r, inferno, 
inferno_r, jet, jet_r, magma, magma_r, mako, mako_r, nipy_spectral, nipy_spectral_r,
ocean, ocean_r, pink, pink_r, plasma, plasma_r, prism, prism_r, rainbow, rainbow_r, 
rocket, rocket_r, seismic, seismic_r, spring, spring_r, summer, summer_r, tab10, 
tab10_r, tab20, tab20_r, tab20b, tab20b_r, tab20c, tab20c_r, terrain, terrain_r, 
viridis, viridis_r, vlag, vlag_r, winter, winter_r]
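Rather than copying the list, you can ask matplotlib which names it knows. A short sketch, assuming matplotlib 3.5 or newer for the matplotlib.colormaps registry; a few names in the list above (icefire, mako, rocket, vlag) come from seaborn and only appear after seaborn has been imported. It also checks the "_r" convention: a reversed colormap starts where the original ends.

```python
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Every colormap name currently registered with matplotlib, sorted alphabetically.
names = plt.colormaps()
print([n for n in names if n.startswith("OrRd")])  # ['OrRd', 'OrRd_r']

# A "_r" colormap is the base colormap reversed: its low end is the base's high end.
greens = matplotlib.colormaps["Greens"]
greens_r = matplotlib.colormaps["Greens_r"]
print(np.allclose(greens_r(0.0), greens(1.0)))  # True
```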

[Figures: the matrix above rendered with cmap = "OrRd", cmap = "Greens_r" and cmap = "OrRd_r".]
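To give the off-diagonal boxes their own colors, as the question asks, one option is to draw two heatmaps on the same Axes and use the mask argument of sns.heatmap to hide the cells each layer should not paint. A sketch under that assumption; heatmap_diag_vs_offdiag and its parameters are hypothetical names, and the color bars are turned off because each layer has its own scale.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch; remove it to use plt.show()
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def heatmap_diag_vs_offdiag(df_cm, diag_cmap="Greens", off_cmap="OrRd"):
    """Overlay two heatmaps on one Axes: one shows only the diagonal, the other the rest."""
    on_diag = np.eye(len(df_cm), dtype=bool)
    fig, ax = plt.subplots(figsize=(10, 7))
    # In sns.heatmap, mask=True hides a cell, so this call keeps only the diagonal ...
    sns.heatmap(df_cm, mask=~on_diag, annot=True, fmt="g", cmap=diag_cmap, cbar=False, ax=ax)
    # ... and this call keeps only the off-diagonal cells, with their own color scale.
    sns.heatmap(df_cm, mask=on_diag, annot=True, fmt="g", cmap=off_cmap, cbar=False, ax=ax)
    return fig, ax


# Small hypothetical matrix; the 11x11 df_cm above can be passed the same way.
df_small = pd.DataFrame([[33, 2, 3], [3, 31, 0], [0, 4, 41]],
                        index=list("ABC"), columns=list("ABC"))
fig, ax = heatmap_diag_vs_offdiag(df_small)
```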

Upvotes: 13

Harun Ismail

Reputation: 109

import matplotlib.pyplot as plt

def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    ...  # body not shown in this answer

You can change the name in cmap=plt.cm.Blues to the color you want, such as green, red or orange. Note that these colormap names are plural: plt.cm.Greens, plt.cm.Reds, plt.cm.Oranges. In addition, each of these colormaps comes in two forms. For example, for green:

  1. Greens: low counts are light and high counts are dark green, so with a good classifier the diagonal is the darkest part.
  2. Greens_r: the reversed colormap, so low counts, which are usually the off-diagonal cells, are dark green.
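Recent scikit-learn versions also provide sklearn.metrics.ConfusionMatrixDisplay, whose plot method takes a cmap argument in the same way. A sketch with made-up labels that draws one matrix with Greens and with Greens_r side by side, so the effect of the _r suffix is visible.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch; remove it to use plt.show()
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

# Hypothetical labels, just to have something to plot.
labels = ["bird", "cat", "dog"]
y_true = ["cat", "dog", "dog", "bird", "cat", "bird", "dog"]
y_pred = ["cat", "dog", "cat", "bird", "cat", "dog", "dog"]

# rows are true labels, columns are predicted labels, both in the order of `labels`
cm = confusion_matrix(y_true, y_pred, labels=labels)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
# Greens: low counts light, high counts dark.
disp = ConfusionMatrixDisplay(cm, display_labels=labels).plot(cmap="Greens", ax=ax1)
# Greens_r: the same colormap reversed, so low counts are dark.
disp_r = ConfusionMatrixDisplay(cm, display_labels=labels).plot(cmap="Greens_r", ax=ax2)
fig.tight_layout()
```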

Upvotes: 1
