Jose Lago

Reputation: 65

seaborn heatmap: y axis ticks and annot

I am trying to plot a heatmap for a simple confusion matrix. My only issue is that the y-axis tick labels and the annotations inside each cell aren't centered vertically.

I have tried the answers to similar questions, but I didn't manage to get it right. Could you please help?

Thanks in advance!

Code:

fig = plt.figure(figsize=[7, 7])
ax = fig.add_subplot(1, 1, 1)
sns.heatmap(confusion_matrix, annot=True, cbar=False, cmap='Blues', ax=ax)
plt.ylabel('Actual Values')
plt.xlabel('Predicted Values')
plt.title('Accuracy Score: {0}'.format(round(accuracy, 2)), size=15)
plt.tight_layout()
plt.show()

[Image: screenshot of the resulting heatmap]

As requested in the comments, here is the full code, so you can see where the heatmap data comes from:

import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn import metrics
import seaborn as sns
import matplotlib.pyplot as plt

bunch = datasets.load_breast_cancer()

def bunch_to_df(bunch):
    # Put features and target side by side and label every column
    data = np.c_[bunch.data, bunch.target]
    columns = np.append(bunch.feature_names, ["target"])
    return pd.DataFrame(data, columns=columns)

df = bunch_to_df(bunch)

# Two features and the binary target (malignant / benign)
x = df[['mean area', 'mean texture']]
y = df.loc[:, ['target']].values

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)

# Standardize the features; the scaler is fitted on the training split only
sc = StandardScaler()
x_train = sc.fit_transform(x_train)
x_test = sc.transform(x_test)

logisticRegr = LogisticRegression()
logisticRegr.fit(x_train, y_train.ravel())

predictions = logisticRegr.predict(x_test)
accuracy = logisticRegr.score(x_test, y_test.ravel())
confusion_matrix = metrics.confusion_matrix(y_test, predictions)

fig = plt.figure(figsize=[7, 7])
ax = fig.add_subplot(1, 1, 1)
sns.heatmap(confusion_matrix, annot=True, cbar=False, cmap='Blues', ax=ax)
plt.ylabel('Actual Values')
plt.xlabel('Predicted Values')
# size is a keyword argument of plt.title, not of str.format
plt.title('Accuracy Score: {0}'.format(round(accuracy, 2)), size=15)
plt.tight_layout()
plt.show()

Upvotes: 2

Views: 1506

Answers (1)

Sverre

Reputation: 54

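I believe this is a bug in the current version of matplotlib (3.1.1), which cuts the top and bottom rows of the heatmap in half, so the tick labels and annotations look off-center. This post may provide an answer.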
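To see whether an installation is affected, you can compare the installed version against the release in question. This is a minimal sketch that assumes only matplotlib 3.1.1 has the bug (3.1.0 predates it, and as far as I know 3.1.2 contains the fix); `parse_version` and `has_heatmap_clipping_bug` are hypothetical helper names, not part of matplotlib.

```python
import re

import matplotlib


def parse_version(version):
    # Keep only the leading "X.Y.Z" numbers; suffixes such as "rc1" or ".post1" are ignored.
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def has_heatmap_clipping_bug(version):
    # Assumption: 3.1.1 is the only release with the clipped heatmap rows.
    return parse_version(version) == (3, 1, 1)


status = "affected" if has_heatmap_clipping_bug(matplotlib.__version__) else "not affected"
print("matplotlib", matplotlib.__version__, "->", status)
```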

You could try setting the y-axis limits manually with ax.set_ylim(n, 0), where n is the number of rows in the matrix (ax.set_ylim(2.0, 0) for the 2x2 confusion matrix here), or reverting matplotlib to version 3.1.0.
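Here is a self-contained sketch of the `set_ylim` workaround, with the limits taken from the matrix shape instead of hard-coded. The matrix values are made up for illustration, and the Agg backend is selected only so the script runs without a display. On an unaffected matplotlib version the call should be harmless, because seaborn already gives an n-row heatmap the limits (n, 0).

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Made-up 2x2 confusion matrix standing in for metrics.confusion_matrix(...)
cm = np.array([[60, 7], [10, 37]])

fig, ax = plt.subplots(figsize=(7, 7))
sns.heatmap(cm, annot=True, fmt="d", cbar=False, cmap="Blues", ax=ax)

# Workaround: make the y-limits span every row, with row 0 at the top.
# For an n-row matrix the limits are (n, 0); here n == 2.
n_rows = cm.shape[0]
ax.set_ylim(n_rows, 0)

ax.set_ylabel("Actual Values")
ax.set_xlabel("Predicted Values")
fig.tight_layout()
fig.canvas.draw()  # render once; call plt.show() instead in an interactive session
```

Deriving `n_rows` from `cm.shape[0]` keeps the fix correct if the classifier later has more than two classes.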

If that doesn't work, you could install the latest version from GitHub. Look at the 'Installing from source' section for instructions.

Upvotes: 2
