BML91

Reputation: 3190

Setting properly aligned axis labels on matplotlib pcolor plot

I have a simple matplotlib pcolor plot which can be reproduced with the following MWE:

import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns
import numpy as np

test_data = np.array([[-0.00842278, -0.03332517, -0.01478557, -0.00275494],
       [ 0.16338327,  0.08383871,  0.03093892,  0.03380778],
       [-0.02246485, -0.1490697 , -0.14918824, -0.12745594],
       [ 0.02477743,  0.1537171 ,  0.13111042,  0.11950057],
       [-0.15408288, -0.04697411, -0.0068787 , -0.01576426],
       [ 0.03508095,  0.19434805,  0.13647802,  0.11276903],
       [-0.16683297,  0.05313956,  0.0283734 ,  0.01179509],
       [-0.08839198, -0.02095752, -0.00573671,  0.00360559],
       [ 0.15476156, -0.06324123, -0.04798161, -0.03844384],
       [-0.056892  , -0.09804484, -0.09506561, -0.08506755],
       [ 0.2318552 , -0.02209629, -0.04530164, -0.02950514],
       [-0.11914883,  0.00965362, -0.02431899, -0.0203009 ],
       [ 0.16025558,  0.02234824, -0.01480751, -0.01487853],
       [ 0.17345419, -0.04348332, -0.07625766, -0.05771962]])

test_df = pd.DataFrame(1 - abs(test_data))
test_df.columns = ['3', '6', '9', '12']
test_df.index = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '15', '20', '25', '30']
plt.pcolor(test_df, cmap=plt.cm.RdYlGn, vmin=0, vmax=1)
plt.show()

Which produces this:

[Figure: plot without aligned axis labels]

As can be seen above, the axis labels are neither correct nor aligned with the coloured rectangles of the plot.
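To see where the misalignment comes from, here is a minimal diagnostic sketch (not part of the original post; the 3 x 4 random array is a hypothetical stand-in for `test_df`). `pcolor` draws the cell in row i, column j over the unit square from (j, i) to (j + 1, i + 1), so the axis limits run from 0 to the number of columns/rows, the default integer ticks fall on cell edges, and the cell centres sit at 0.5, 1.5, 2.5, ...

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for test_df: 3 rows x 4 columns of values in [0, 1)
data = np.random.default_rng(0).random((3, 4))

fig, ax = plt.subplots()
ax.pcolor(data, vmin=0, vmax=1)

# Cell (i, j) covers [j, j+1] x [i, i+1], so the limits match the grid size
print(ax.get_xlim(), ax.get_ylim())

# Integer positions are cell edges; the centres lie halfway between them
col_centres = np.arange(data.shape[1]) + 0.5
row_centres = np.arange(data.shape[0]) + 0.5
print(col_centres)  # [0.5 1.5 2.5 3.5]
print(row_centres)  # [0.5 1.5 2.5]
```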

I can partly create the intended labelling on the x axis with the following code:

ax = plt.gca()
labels = [u'', u'3', u'', u'6', u'', u'9', u'', u'12', u'']
ax.set_xticklabels(labels)

Which produces this:

[Figure: second plot with corrected x axis]

My problem is that I can't reproduce this on the y axis, because there the labels don't line up with the centres of the rectangles.

Is there a way to make the x and y axis labels match the DataFrame's column names and index, while ensuring the labels are centred on the rectangles rather than on their edges?

Upvotes: 3

Views: 5881

Answers (2)

CoMartel

Reputation: 3591

I found another solution that I think is more straightforward, using seaborn (imported as sns):

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib  # IPython magic: enables an interactive matplotlib backend (IPython/Jupyter only)

test_data = np.array([[-0.00842278, -0.03332517, -0.01478557, -0.00275494],
       [ 0.16338327,  0.08383871,  0.03093892,  0.03380778],
       [-0.02246485, -0.1490697 , -0.14918824, -0.12745594],
       [ 0.02477743,  0.1537171 ,  0.13111042,  0.11950057],
       [-0.15408288, -0.04697411, -0.0068787 , -0.01576426],
       [ 0.03508095,  0.19434805,  0.13647802,  0.11276903],
       [-0.16683297,  0.05313956,  0.0283734 ,  0.01179509],
       [-0.08839198, -0.02095752, -0.00573671,  0.00360559],
       [ 0.15476156, -0.06324123, -0.04798161, -0.03844384],
       [-0.056892  , -0.09804484, -0.09506561, -0.08506755],
       [ 0.2318552 , -0.02209629, -0.04530164, -0.02950514],
       [-0.11914883,  0.00965362, -0.02431899, -0.0203009 ],
       [ 0.16025558,  0.02234824, -0.01480751, -0.01487853],
       [ 0.17345419, -0.04348332, -0.07625766, -0.05771962]])


Cols = ['3', '6', '9', '12']
Index = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '15', '20', '25', '30']
test_df = pd.DataFrame(1 - abs(test_data), index=Index, columns=Cols)
sns.heatmap(test_df, cmap=plt.cm.RdYlGn, vmin=0, vmax=1, cbar=True)
plt.show()

As you will notice, it plots exactly the same data, but uses the DataFrame's index and columns directly as the axis labels, centred on the cells.

Another difference is that it keeps the DataFrame's original row order (first row at the top), so the graph is not upside-down as it is with the matplotlib solution.

Note that on my computer I have to enlarge the window to see the colorbar; otherwise it is hidden.
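If you would rather stay with plain matplotlib but get the same top-to-bottom row order as seaborn, one option (an alternative sketch, not part of this answer) is to centre the ticks on the cells and then invert the y axis with `ax.invert_yaxis()`. The small DataFrame is a hypothetical stand-in for `test_df`.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in for test_df (3 rows, 2 columns)
df = pd.DataFrame([[0.9, 0.8], [0.7, 0.6], [0.5, 0.4]],
                  index=["1", "2", "3"], columns=["3", "6"])

fig, ax = plt.subplots()
ax.pcolor(df, cmap=plt.cm.RdYlGn, vmin=0, vmax=1)

# Centre the ticks on the cells and label them from the DataFrame
ax.set_xticks(np.arange(len(df.columns)) + 0.5)
ax.set_xticklabels(df.columns)
ax.set_yticks(np.arange(len(df.index)) + 0.5)
ax.set_yticklabels(df.index)

# Flip the y axis so the first DataFrame row is drawn at the top, as seaborn does
ax.invert_yaxis()
```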

Upvotes: 1

tmdavison

Reputation: 69166

It's not great to do it like this (you are decoupling the tick labels from the data), but you can do this:

fig,ax = plt.subplots()

ax.pcolor(test_df, cmap=plt.cm.RdYlGn, vmin=0, vmax=1)

ax.set_yticks(np.arange(len(test_df.index))+0.5)
ax.set_yticklabels(test_df.index)

ax.set_xticks(np.arange(len(test_df.columns))+0.5)
ax.set_xticklabels(test_df.columns)

plt.show()

We set the ticks at 0.5, 1.5, 2.5, etc. (the centres of the cells), and then set the tick labels from your DataFrame's index and columns.
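One way to limit the decoupling mentioned above is to derive both the tick positions and the labels from the same DataFrame in one place. A minimal sketch under that idea; `plot_labelled_pcolor` is a hypothetical helper name, not a matplotlib or pandas function, and the random 4 x 3 DataFrame stands in for `test_df`:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_labelled_pcolor(ax, df, **pcolor_kwargs):
    """Hypothetical helper: draw df with pcolor and label each cell centre.

    Tick positions and labels both come from df, so they stay in step with
    the shape of the data that was plotted.
    """
    mesh = ax.pcolor(df, **pcolor_kwargs)
    n_rows, n_cols = df.shape
    # Cell (i, j) spans [j, j+1] x [i, i+1]; its centre is (j + 0.5, i + 0.5)
    ax.set_xticks(np.arange(n_cols) + 0.5)
    ax.set_xticklabels(df.columns.astype(str))
    ax.set_yticks(np.arange(n_rows) + 0.5)
    ax.set_yticklabels(df.index.astype(str))
    return mesh


# Hypothetical stand-in data with numeric index/column labels
df = pd.DataFrame(np.random.default_rng(0).random((4, 3)),
                  index=[1, 2, 3, 4], columns=[3, 6, 9])
fig, ax = plt.subplots()
mesh = plot_labelled_pcolor(ax, df, cmap=plt.cm.RdYlGn, vmin=0, vmax=1)
fig.colorbar(mesh, ax=ax)
```

Passing `**pcolor_kwargs` through keeps the colormap and value range choices with the caller, while the helper owns only the label placement.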


Upvotes: 4
