Jan D.M.

Why are my matplotlib subplots not the same size?

I am building a classifier in Python and would like to evaluate it. I have already written some code that plots several metrics per class. plot_metric() is a self-defined function.

import matplotlib.pyplot as plt
from matplotlib import gridspec
import scikitplot as skplot

# One row: a wide panel for the confusion matrix, then four narrow metric columns.
gs = gridspec.GridSpec(1, 5, width_ratios=[5, 1, 1, 1, 1])
ax1 = plt.subplot(gs[0])
ax2 = plt.subplot(gs[1])
ax3 = plt.subplot(gs[2])
ax4 = plt.subplot(gs[3])
ax5 = plt.subplot(gs[4])

skplot.metrics.plot_confusion_matrix(confusion_y_test, confusion_y_pred, ax=ax1, labels=[0, 1, 2, 3, 4])
plot_metric(accuracy_per_class_array, "Class accuracy", ax2, cmap=plt.cm.Blues)
plot_metric(precision_per_class_array, "Class precision", ax3, cmap=plt.cm.Blues)
plot_metric(recall_per_class_array, "Class recall", ax4, cmap=plt.cm.Blues)
plot_metric(f1_per_class_Array, "Class F1", ax5, cmap=plt.cm.Blues)
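For comparison, here is a minimal, self-contained sketch of the same five-column layout, built with `plt.subplots(..., gridspec_kw=...)` instead of separate `plt.subplot` calls. Plain `imshow` calls on made-up numbers stand in for the scikit-plot confusion matrix and for `plot_metric()`, and the helper `drawn_heights` is a hypothetical name introduced here. The sketch draws the figure and reports each axes' height in figure coordinates, which is one way to check whether the columns come out equal when nothing else modifies them.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

n_classes = 5
rng = np.random.default_rng(0)  # made-up data, only the shapes matter here

# Same width ratios as in the question: one wide panel, four narrow columns.
fig, axes = plt.subplots(1, 5, gridspec_kw={"width_ratios": [5, 1, 1, 1, 1]})

# Stand-in for the confusion matrix: an n_classes x n_classes image.
axes[0].imshow(rng.integers(0, 20, size=(n_classes, n_classes)), cmap=plt.cm.Blues)

# Stand-ins for the per-class metrics: one n_classes x 1 column each.
for ax in axes[1:]:
    ax.imshow(rng.random((n_classes, 1)), cmap=plt.cm.Blues)


def drawn_heights(fig, axes):
    """Hypothetical helper: each axes' height (figure fraction) after drawing."""
    # imshow uses aspect="equal", so the final boxes are only known after a draw.
    fig.canvas.draw()
    return [ax.get_position().height for ax in axes]


heights = drawn_heights(fig, axes)
print([round(h, 3) for h in heights])
```

With this synthetic data the five heights agree: the wide panel holds a square image in five times the width of a one-cell-wide column, so both shapes end up equally tall. If the real figure differs, something in the real code is changing one of the axes after the layout is created.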

I use each axis to plot a different metric, with the confusion matrix in the first one, as you can see in the picture.

But why is the last column shorter than the others? I would like it to have the same height as the other plots.
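One possible cause, not confirmed in this thread: if the installed scikit-plot version adds its colorbar with `plt.colorbar(...)` and no `ax` argument (an assumption worth checking in its source), Matplotlib releases before 3.6 take the colorbar's space from the *current* axes rather than from the confusion-matrix axes. After the five `plt.subplot` calls above, the current axes is `ax5`, the last column. The sketch below simulates that situation on any Matplotlib version by passing `ax=axes[4]` explicitly, with made-up data standing in for the real plots.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

n_classes = 5
rng = np.random.default_rng(0)  # made-up data, only the shapes matter here

fig, axes = plt.subplots(1, 5, gridspec_kw={"width_ratios": [5, 1, 1, 1, 1]})

# Stand-in for the confusion matrix and for the four metric columns.
cm_image = axes[0].imshow(rng.integers(0, 20, size=(n_classes, n_classes)), cmap=plt.cm.Blues)
for ax in axes[1:]:
    ax.imshow(rng.random((n_classes, 1)), cmap=plt.cm.Blues)

# Simulate a colorbar that takes its space from the LAST axes, which is what a
# current-axes-based colorbar would do when ax5 was the most recently created axes.
fig.colorbar(cm_image, ax=axes[4])

# Draw so the equal aspect is applied, then compare the drawn heights.
fig.canvas.draw()
heights = [ax.get_position().height for ax in axes]
print([round(h, 3) for h in heights])
```

The last column loses part of its width to the colorbar, and because `imshow` keeps an equal aspect, it also becomes shorter. If this turns out to be the cause, one possible fix is to make `ax1` current with `plt.sca(ax1)` right before the `skplot.metrics.plot_confusion_matrix(...)` call, so the colorbar takes its space from the confusion-matrix panel instead; `ax1` then becomes somewhat narrower and, with the equal aspect, possibly shorter as well.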


Update:

The code of plot_metric():

import numpy as np


def plot_metric(metric_array, metric_name, axis, display_labels=None, cmap="viridis"):
    """
    Draw one per-class metric as a single-column heatmap with the values written in the cells.

    metric_array : sequence of float
        One metric value per class
    metric_name : str
        Label shown under the column
    axis : matplotlib `AxesSubplot`
        Axes to draw on, e.g. from `fig, (ax1, ax2) = plt.subplots(1, 2)`
    display_labels : list of labels or None
        Human-readable class names
    cmap : colormap, optional
        Optional colormap
    """

    n_classes = len(metric_array)

    if display_labels is None:
        labels = np.arange(n_classes)
    else:
        labels = display_labels

    # Show the values as an (n_classes x 1) image; imshow uses an equal aspect ratio by default.
    axis.imshow(
        np.array(metric_array).reshape(n_classes, 1),
        interpolation="nearest",
        cmap=cmap,
    )

    # Write each value (2 significant digits) into its cell.
    for i, value in enumerate(metric_array):
        axis.text(0, i, format(value, ".2g"), ha="center", va="center")

    axis.set(
        yticks=np.arange(n_classes),
        xlabel=metric_name,
        yticklabels=labels,
    )
    # Hide the x ticks: the column has a single cell.
    axis.tick_params(
        axis="x", bottom=False, labelbottom=False,
    )
    # Put class 0 at the top, like the rows of the confusion matrix.
    axis.set_ylim((n_classes - 0.5, -0.5))
