I have tried around 15 different methods for setting the y-label in this simple confusion matrix visualization code. For now, I have resorted to labeling the rows directly as 'Predicted Positive' and 'Predicted Negative', but I would prefer to have 'Predicted' outside the table, as I do with 'Actual'. I'm very confused about what's going wrong; I assume it has something to do with the fact that I'm plotting a table. Removing the row labels does not fix the issue. Thanks in advance!
import matplotlib.pyplot as plt

def plot_conf_mat(data, model_name):
    '''
    Plot a confusion matrix based on the array data.
    Expected: 2x2 matrix of form
    [[TP, FP],
     [FN, TN]].
    Outputs a simple colored confusion matrix table.
    '''
    # Set font sizes
    SMALL_SIZE = 20
    MEDIUM_SIZE = 25
    BIGGER_SIZE = 30
    plt.rc('font', size=MEDIUM_SIZE)         # controls default text sizes
    plt.rc('axes', titlesize=MEDIUM_SIZE)    # font size of the axes title
    plt.rc('axes', labelsize=SMALL_SIZE)     # font size of the x and y labels
    plt.rc('xtick', labelsize=SMALL_SIZE)    # font size of the tick labels
    plt.rc('ytick', labelsize=SMALL_SIZE)    # font size of the tick labels
    plt.rc('legend', fontsize=SMALL_SIZE)    # legend font size
    plt.rc('figure', titlesize=BIGGER_SIZE)  # font size of the figure title

    # Prepare table
    columns = ('Positive', 'Negative')
    rows = ('Predicted\nPositive', 'Predicted\nNegative')
    cell_text = data
    colors = [["tab:green", "tab:red"], ["tab:red", "tab:grey"]]

    fig, ax = plt.subplots(figsize=(6, 5))
    ax.axis('tight')
    ax.axis('off')
    # Add a table centered in the axes
    the_table = ax.table(cellText=cell_text, cellColours=colors,
                         colLabels=columns, rowLabels=rows, loc='center')
    the_table.scale(2, 5)
    the_table.set_fontsize(20)  # apparently it doesn't adhere to plt.rc??
    ax.set_title(f'{model_name} Confusion Matrix: \n\nActual')
    ax.set_ylabel('Predicted')  # doesn't work!!
    fig.savefig(f"{model_name}_conf_mat.pdf", bbox_inches='tight')
    plt.show()
Output (model name redacted): [figure: the rendered confusion matrix table]
First, note that sklearn.metrics provides a visualization class called ConfusionMatrixDisplay, which might do what you are looking for. See if that helps.
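A minimal sketch of the ConfusionMatrixDisplay route, assuming scikit-learn is installed. ConfusionMatrixDisplay draws the matrix as a heatmap on ordinary, visible axes, so set_ylabel works normally. By default it labels the rows "True label" and the columns "Predicted label"; since the question's matrix has predicted values in the rows, the sketch simply relabels both axes and moves the x-axis to the top. The function name plot_conf_mat_sklearn and the numbers in the usage line are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # optional: non-interactive backend so this runs headless
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay


def plot_conf_mat_sklearn(data, model_name):
    """Hypothetical helper: plot a 2x2 matrix laid out as [[TP, FP], [FN, TN]],
    i.e. rows = predicted, columns = actual (the question's layout)."""
    disp = ConfusionMatrixDisplay(confusion_matrix=np.asarray(data),
                                  display_labels=['Positive', 'Negative'])
    disp.plot(cmap='Blues')
    ax = disp.ax_
    # sklearn assumes rows = true and columns = predicted; this layout is
    # transposed relative to that, so override both axis labels.
    ax.set_ylabel('Predicted')
    ax.set_xlabel('Actual')
    # Put 'Actual' and its tick labels above the matrix, as in the question.
    ax.xaxis.set_label_position('top')
    ax.xaxis.tick_top()
    ax.set_title(f'{model_name} Confusion Matrix')
    return disp


disp = plot_conf_mat_sklearn([[50, 10], [5, 35]], "demo")  # illustrative numbers
```

The trade-off is that you get a heatmap with a colorbar rather than the hand-colored table, but labels, ticks and font sizes then follow the usual Axes machinery (including plt.rc).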
As for the table itself: matplotlib's table is added to an Axes, which usually holds a plot as well. Because you only need the table, you hide the plot with ax.axis('off'). If you comment out that line, you will see the borders of the axes. This is also why ax.set_ylabel() has no visible effect: the y-label belongs to the axes, and ax.axis('off') hides the axes together with their labels.
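To make the explanation concrete, the sketch below shows that set_ylabel does store the text, but ax.axis('off') sets the axes' axison flag to False, so nothing belonging to the axes is drawn. It then shows an alternative that is not from the original answer: keep the axes on, but remove the ticks and hide the spines, so the y-label is still rendered. The cell values are illustrative, and with row labels or table.scale() the table may grow past the axes box, so the label position may still need adjusting.

```python
import matplotlib
matplotlib.use("Agg")  # optional: non-interactive backend so this runs headless
import matplotlib.pyplot as plt

# 1) The question's setup: the label is set, but the whole axes is switched off.
fig, ax = plt.subplots()
ax.set_ylabel('Predicted')
ax.axis('off')
hidden = (ax.get_ylabel(), ax.axison)  # label text exists, axes not drawn

# 2) Alternative: keep the axes "on" but make the frame invisible.
fig2, ax2 = plt.subplots()
ax2.set_xticks([])  # no tick marks or tick labels
ax2.set_yticks([])
for spine in ax2.spines.values():
    spine.set_visible(False)  # hide the box around the axes
ax2.set_ylabel('Predicted')  # drawn, because the axes are still on
ax2.table(cellText=[[50, 10], [5, 35]], loc='center')  # illustrative values
```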
A simple workaround is to place the text yourself. Adding the line below instead of the set_ylabel() call did the trick; you may need to fine-tune the x and y coordinates.
plt.text(-0.155, -0.0275, 'Predicted', fontsize=SMALL_SIZE, rotation=90)
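plt.text interprets its x and y in data coordinates, which for an axes with nothing plotted depend on whatever limits matplotlib chose, so values like -0.155 are hard to reason about. A variation on the same workaround, sketched below under that assumption, passes transform=ax.transAxes so that (0, 0) is the lower-left and (1, 1) the upper-right corner of the axes regardless of the data limits. The x offset of -0.15 is a guess you would tune to clear the row labels; the cell values are illustrative. On matplotlib 3.4 or newer, fig.supylabel('Predicted') is another option that places a label at the figure level instead of on the hidden axes.

```python
import matplotlib
matplotlib.use("Agg")  # optional: non-interactive backend so this runs headless
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 5))
ax.axis('off')
ax.table(cellText=[[50, 10], [5, 35]],  # illustrative values
         colLabels=('Positive', 'Negative'),
         rowLabels=('Positive', 'Negative'), loc='center')

# Axes-relative coordinates: y=0.5 centers the text vertically on the axes;
# x=-0.15 (an assumed value to tune) puts it to the left of the axes box.
label = ax.text(-0.15, 0.5, 'Predicted', rotation=90,
                ha='right', va='center', transform=ax.transAxes)
```

Because the position is relative to the axes, it stays put if the data limits change, although it still has to be adjusted if you rescale the table.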