Romaker

Reputation: 171

How to change plot_confusion_matrix default figure size in sklearn.metrics package

I tried to plot a confusion matrix in a Jupyter notebook using sklearn.metrics.plot_confusion_matrix, but the default figure size is a little small. I added plt.figure(figsize=(20, 20)) before plotting, but the figure size did not change, and the output only showed the text 'Figure size 1440x1440 with 0 Axes'. How can I change the figure size?

%matplotlib inline
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import plot_confusion_matrix
from matplotlib import pyplot as plt

plt.figure(figsize=(20, 20))
clf = GradientBoostingClassifier(random_state=42)
clf.fit(X_train, y_train)
plot_confusion_matrix(clf, X_test, y_test, cmap=plt.cm.Blues)
plt.title('Confusion matrix')
plt.show()


Upvotes: 17

Views: 29780

Answers (3)

dom free

Reputation: 1235

I use set_figwidth and set_figheight to specify the figure size:

from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt

disp = ConfusionMatrixDisplay.from_predictions(
    [0, 1, 1, 0, 1],  # y_true
    [0, 1, 0, 1, 0],  # y_pred
    labels=[1, 0],
    cmap=plt.cm.Blues,
    display_labels=['Good', 'Bad'],
    values_format='',
)
# Resize the figure that the display created (sizes are in inches)
fig = disp.ax_.get_figure()
fig.set_figwidth(3)
fig.set_figheight(3)

Upvotes: 3

Real Uniquee

Reputation: 531

ConfusionMatrixDisplay offers more control and flexibility than plot_confusion_matrix when visualizing a confusion matrix (plot_confusion_matrix was deprecated in scikit-learn 1.0 and removed in 1.2). For more info, see the docs.

import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

y_true = [0, 1, 1, 0, 1]
y_pred = [0, 1, 0, 1, 0]
labels = ['Good', 'Bad']  # 0: Good and 1: Bad
disp = ConfusionMatrixDisplay.from_predictions(
    y_true,
    y_pred,
    display_labels=labels,
    cmap=plt.cm.Blues,
)
fig = disp.figure_  # the figure created by the display
fig.set_figwidth(10)
fig.set_figheight(10)
fig.suptitle('Plot of confusion matrix')

Upvotes: 2

Hovanes Gasparian

Reputation: 382

I don't know why BigBen posted that as a comment, rather than an answer, but I almost missed seeing it. Here it is as an answer, so future onlookers don't make the same mistake I almost made!

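import matplotlib.pyplot as plt
from sklearn.metrics import plot_confusion_matrix  # scikit-learn < 1.2 only
fig, ax = plt.subplots(figsize=(10, 10))  # pre-sized figure and axes
plot_confusion_matrix(your_model, X_test, y_test, ax=ax)  # draw into that ax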

Upvotes: 25
