Daniel Frees

Reputation: 53

Why won't matplotlib display the y-axis label on my tables?

I have tried around 15 different methods of setting the y-label for this simple confusion matrix visualization code. Currently, I have resorted to labeling the rows directly as 'Predicted Positive' and 'Predicted Negative', but I would prefer to have 'Predicted' outside the table, as I do with 'Actual'. I'm very confused about what's going wrong; I assume it has something to do with the fact that I'm plotting a table. Removing the row labels does not fix the issue. Thanks in advance!

import matplotlib.pyplot as plt

def plot_conf_mat(data, model_name):
    '''
    Plot a confusion matrix based on the array data.
    
    Expected: 2x2 matrix of form
    [[TP, FP],
     [FN, TN]].
     
    Outputs a simple colored confusion matrix table
    '''
    
    #set fontsizes
    SMALL_SIZE = 20
    MEDIUM_SIZE = 25
    BIGGER_SIZE = 30

    plt.rc('font', size=MEDIUM_SIZE)          # controls default text sizes
    plt.rc('axes', titlesize=MEDIUM_SIZE)     # fontsize of the axes title
    plt.rc('axes', labelsize=SMALL_SIZE)    # fontsize of the x and y labels
    plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
    plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title
    
    # Prepare table
    columns = ('Positive', 'Negative')
    rows = ('Predicted\nPositive', 'Predicted\nNegative')
    cell_text = data
    # Cell background colours: TP green, FP/FN red, TN grey
    colors = [["tab:green","tab:red"],[ "tab:red","tab:grey"]]
    
    fig, ax = plt.subplots(figsize = (6,5))
    ax.axis('tight')
    ax.axis('off')
    the_table = ax.table(cellText=cell_text,cellColours=colors,
                         colLabels=columns, rowLabels = rows, loc='center')

    the_table.scale(2,5)
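    the_table.auto_set_font_size(False)  # stop the table from auto-shrinking its font to fit the cells
    the_table.set_fontsize(20)  # apparently it doesn't adhere to plt.rc??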
    ax.set_title(f'{model_name} Confusion Matrix: \n\nActual')
    
    ax.set_ylabel('Predicted') #doesn't work!!
        
    fig.savefig(f"{model_name}_conf_mat.pdf", bbox_inches = 'tight')
    
    plt.show()

Out (model name redacted):

[Image: confusion matrix output]

Upvotes: 0

Views: 727

Answers (1)

Redox

Reputation: 9967

Firstly, did you know that sklearn.metrics has a visualization class called ConfusionMatrixDisplay? It might do what you are looking for, so see if that helps.
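A minimal sketch of that route, assuming a reasonably recent scikit-learn and the question's [[TP, FP], [FN, TN]] layout (rows = predicted, columns = actual). `plot_conf_mat_sklearn` and the sample counts are hypothetical. By default ConfusionMatrixDisplay labels the rows "True label" and the columns "Predicted label" (scikit-learn's own convention), so the sketch overrides both labels to match the question's orientation. Because this is an ordinary image plot rather than a hidden Axes, `set_ylabel` works as expected.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to get an interactive window
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay


def plot_conf_mat_sklearn(data, model_name):
    """Hypothetical helper. `data` uses the question's layout:
    rows = predicted, columns = actual, i.e. [[TP, FP], [FN, TN]]."""
    disp = ConfusionMatrixDisplay(confusion_matrix=np.asarray(data),
                                  display_labels=["Positive", "Negative"])
    disp.plot(cmap="Blues")  # creates a new figure and Axes when ax is not given
    # scikit-learn labels rows "True label" and columns "Predicted label";
    # relabel so the rows read as "Predicted" and the columns as "Actual".
    disp.ax_.set_ylabel("Predicted")
    disp.ax_.set_xlabel("Actual")
    disp.ax_.set_title(f"{model_name} Confusion Matrix")
    return disp


disp = plot_conf_mat_sklearn([[50, 10], [5, 35]], "demo")
plt.close(disp.figure_)
```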

As for the table itself, ax.table() adds a table to an existing axis, which normally also contains a plot. Since you only need the table, you are hiding the plot with ax.axis('off'). If you comment out that line, you will see the borders of the plot. This is why ax.set_ylabel() has no visible effect: the y-label belongs to the axis, and ax.axis('off') hides the axis labels along with the ticks and frame.
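If you would rather keep using set_ylabel(), one option (a sketch I have not tried against the exact figure above) is to hide only the ticks and the frame instead of calling ax.axis('off'), so the axis stays "on" and its label is still drawn. `table_with_ylabel` is a hypothetical helper, and because the_table.scale(2, 5) makes the table extend past the axis, the `labelpad` value is a guess that will likely need tuning.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt


def table_with_ylabel(data, rows, columns, ylabel="Predicted", labelpad=20):
    """Hypothetical helper: draw a table but keep the Axes 'on' so its y-label renders."""
    fig, ax = plt.subplots(figsize=(6, 5))
    # ax.axis('off') would hide the axis label along with the ticks and frame,
    # so hide only the ticks and the frame (spines) instead.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.table(cellText=data, colLabels=columns, rowLabels=rows, loc="center")
    ax.set_ylabel(ylabel, labelpad=labelpad)  # labelpad is measured in points
    return fig, ax


fig, ax = table_with_ylabel([[50, 10], [5, 35]],
                            rows=("Positive", "Negative"),
                            columns=("Positive", "Negative"))
fig.canvas.draw()  # render once to make sure nothing errors
```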

A simple workaround is to place the text yourself. Using the line below instead of set_ylabel() did the trick. You may need to fine-tune the x and y coordinates.

plt.text(-0.155, -0.0275, 'Predicted', fontsize=SMALL_SIZE, rotation=90)
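The coordinates in that line are data coordinates of the current axis, which is why they take trial and error. A variation (an assumption, not something tested in this answer) is to pass `transform=ax.transAxes`, so that (0, 0) is the lower-left corner of the axis and (1, 1) the upper-right regardless of the data limits. Then y=0.5 centers the label vertically and only the x offset needs tuning. `add_side_label`, its default x offset and the sample counts are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt


def add_side_label(ax, label="Predicted", x=-0.1, fontsize=20):
    """Hypothetical helper: put a rotated label to the left of `ax`."""
    # transform=ax.transAxes: coordinates are fractions of the Axes box,
    # so x < 0 sits left of the Axes and y = 0.5 is its vertical middle.
    # Text artists are still drawn after ax.axis('off').
    return ax.text(x, 0.5, label, transform=ax.transAxes,
                   rotation=90, ha="center", va="center", fontsize=fontsize)


fig, ax = plt.subplots(figsize=(6, 5))
ax.axis("off")
ax.table(cellText=[[50, 10], [5, 35]], colLabels=("Positive", "Negative"),
         rowLabels=("Positive", "Negative"), loc="center")
label = add_side_label(ax)
fig.canvas.draw()
```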


Upvotes: 2
