S.V

Reputation: 2793

Making a seaborn heatmap plotted from a DataFrame aware of the data ranges

How can one make a seaborn heatmap (created from a pandas DataFrame) aware of the data ranges? When I hover the mouse pointer over the plot, the bottom-right corner of the plot window shows only "x= y=", whereas I want to see the coordinates of the point under the cursor (for example, "x=25.6, y=3.3"). This assumes, of course, that the input DataFrame holds a 2D histogram with equal-size bins along each axis.

Alternatively, is there a different way to create such a plot that achieves the same effect? With ax.hist2d, for example, I get this behavior out of the box, but I want to compute the content of each bin with my own code and still display the result as a heatmap (with the bin contents color-coded).
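A minimal sketch of the hist2d-style route, assuming made-up sample data and using np.histogram2d only as a stand-in for the custom per-bin computation. Drawing the bin contents with ax.pcolormesh and explicit bin edges keeps ordinary numeric axes, so the hover readout reports data coordinates just as hist2d does. The edges below are chosen so the bins are centred on the question's 10..50 and 1..5 values.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up sample data, for illustration only.
rng = np.random.default_rng(0)
x = rng.uniform(5, 55, size=1000)
y = rng.uniform(0.5, 5.5, size=1000)

# Equal-width bin edges: 5 bins per axis, centred on 10..50 and 1..5.
xedges = np.linspace(5, 55, 6)
yedges = np.linspace(0.5, 5.5, 6)

# Stand-in for the custom per-bin computation: plain counts.
# H[i, j] holds the content of x-bin i and y-bin j.
H, _, _ = np.histogram2d(x, y, bins=[xedges, yedges])

fig, ax = plt.subplots()
# pcolormesh expects C indexed as [y, x], hence the transpose.
mesh = ax.pcolormesh(xedges, yedges, H.T)
fig.colorbar(mesh, ax=ax)
plt.show()
```

Any array of shape (len(xedges) - 1, len(yedges) - 1) can replace H, so the colour coding is independent of how the bin contents were computed.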

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Bin centres of a 2D histogram with equal-size bins along each axis
Index = [ 1.0,  2.0,  3.0,  4.0,  5.0]
Cols  = [10.0, 20.0, 30.0, 40.0, 50.0]
df = pd.DataFrame(abs(np.random.randn(5, 5)),
    index=Index, columns=Cols)

plt.close(1)
fig, ax = plt.subplots(num=1)
sns.heatmap(df, annot=True, ax=ax)   # hovering shows only "x= y="
plt.show(block=False)
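The empty "x= y=" readout can be reproduced with Matplotlib alone. As far as I can tell, seaborn sets the heatmap tick labels through a matplotlib.ticker.FixedFormatter. When drawing ticks, the axis passes a tick position, so the label is found. The status bar, however, goes through format_data_short(x) without a position, and FixedFormatter returns an empty string in that case. The label strings below are illustrative copies of the question's Cols.

```python
from matplotlib.ticker import FixedFormatter

# A FixedFormatter like the one behind the heatmap's x tick labels.
fmt = FixedFormatter(["10.0", "20.0", "30.0", "40.0", "50.0"])

# Tick labelling passes a tick position, so the label is found ...
print(repr(fmt(15.0, pos=1)))             # '20.0'
# ... but the hover readout calls format_data_short(x) with no position,
# and FixedFormatter then returns an empty string.
print(repr(fmt.format_data_short(15.0)))  # ''
```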

Thank you for your help!

Upvotes: 0

Views: 842

Answers (2)

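import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FixedFormatter


class CustomFormatter(FixedFormatter):
    """Tick-label formatter that also fills in the hover readout.

    A plain FixedFormatter returns '' when called without a tick position,
    which is what the status bar does. This subclass returns the label of
    the tick nearest to x instead.
    """
    def __init__(self, old):
        super().__init__(old.seq)

    def __call__(self, x, pos=None):
        # self.locs holds the tick locations once the axes have been drawn
        return self.seq[np.abs(np.asarray(self.locs) - x).argmin()]

# Run after sns.heatmap(...) has drawn into the current axes.
ax = plt.gca()
ax.xaxis.set_major_formatter(CustomFormatter(ax.xaxis.get_major_formatter()))
ax.yaxis.set_major_formatter(CustomFormatter(ax.yaxis.get_major_formatter()))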
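A variation that keeps sns.heatmap but reports interpolated values (e.g. 25.6) rather than the nearest tick label: Matplotlib's Axes.fmt_xdata and Axes.fmt_ydata, when set to a callable, are used for the hover readout instead of the axis formatter. sns.heatmap draws cell k over [k, k + 1] on each axis, so with equal-size bins the mapping from heatmap coordinate to data value is linear. The helper name cell_to_data is hypothetical; this is a sketch under the question's equal-bin assumption.

```python
import numpy as np


def cell_to_data(centres):
    """Hypothetical helper: map a heatmap axis coordinate to data units.

    Cell k spans [k, k + 1]; its centre k + 0.5 corresponds to centres[k].
    Equal-size bins are assumed, which makes the mapping linear.
    """
    c = np.asarray(centres, dtype=float)
    step = c[1] - c[0]  # bin width

    def fmt(coord):
        # Readout callables receive a float and must return a string.
        return f"{c[0] + (coord - 0.5) * step:.4g}"

    return fmt


# Usage after the question's sns.heatmap(df, annot=True, ax=ax) call:
#   ax.fmt_xdata = cell_to_data(df.columns)
#   ax.fmt_ydata = cell_to_data(df.index)
x_fmt = cell_to_data([10.0, 20.0, 30.0, 40.0, 50.0])
y_fmt = cell_to_data([1.0, 2.0, 3.0, 4.0, 5.0])
print(x_fmt(2.06), y_fmt(2.8))  # 25.6 3.3
```

Because seaborn inverts the y axis but still places row k over [k, k + 1], the same helper works for both axes.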

Upvotes: 0

ImportanceOfBeingErnest

Reputation: 339350

If you replace sns.heatmap(...) with ax.imshow(...), you're close to what you need: you can then set the extent of the image to the data range, so that the cursor readout is given in data coordinates.
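The only arithmetic involved is that imshow's extent, given as [left, right, bottom, top], must reach half a bin beyond the outermost bin centres; otherwise the cells are stretched and their centres no longer line up with the column and index values. A small sketch of that calculation, with a hypothetical helper name and equal-size bins assumed:

```python
import numpy as np


def centres_to_extent(x_centres, y_centres):
    """Hypothetical helper: imshow extent from equally spaced bin centres."""
    x = np.asarray(x_centres, dtype=float)
    y = np.asarray(y_centres, dtype=float)
    dx = (x[1] - x[0]) / 2  # half a bin width along x
    dy = (y[1] - y[0]) / 2  # half a bin width along y
    return [float(x.min() - dx), float(x.max() + dx),
            float(y.min() - dy), float(y.max() + dy)]


extent = centres_to_extent([10.0, 20.0, 30.0, 40.0, 50.0],
                           [1.0, 2.0, 3.0, 4.0, 5.0])
print(extent)  # [5.0, 55.0, 0.5, 5.5]
```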

import numpy as np; np.random.seed(42)
import pandas as pd
import matplotlib.pyplot as plt

Index = [ 1.0,  2.0,  3.0,  4.0,  5.0]
Cols  = [10.0, 20.0, 30.0, 40.0, 50.0]
df = pd.DataFrame(abs(np.random.randn(5, 5)), 
    index=Index, columns=Cols)
plt.close(1)
fig,ax = plt.subplots(num=1)
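# Half a bin width on each axis (bins assumed equally spaced), so that
# each cell is centred on its column/index value.
dx = np.diff(df.columns)[0]/2
dy = np.diff(df.index)[0]/2
extent = [df.columns.min()-dx, df.columns.max()+dx,
          df.index.min()-dy, df.index.max()+dy]
# origin="lower" places row 0 (index 1.0) at the bottom, consistent with
# the [left, right, bottom, top] order of extent.
ax.imshow(df, extent=extent, origin="lower", aspect="auto")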

plt.show()

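imshow drops two things the question's sns.heatmap(df, annot=True) provided: the colour bar and the per-cell numbers. A sketch that restores both, assuming the same 5×5 layout (sample data seeded for reproducibility). It passes origin="lower" so that row 0 (index 1.0) is drawn at the bottom, at y = 1; with that choice each bin centre sits at (column value, index value) in data coordinates, so the labels can be placed there directly.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

rng = np.random.default_rng(42)
df = pd.DataFrame(np.abs(rng.standard_normal((5, 5))),
                  index=[1.0, 2.0, 3.0, 4.0, 5.0],
                  columns=[10.0, 20.0, 30.0, 40.0, 50.0])

fig, ax = plt.subplots()
dx = np.diff(df.columns)[0] / 2  # half bin widths, equal bins assumed
dy = np.diff(df.index)[0] / 2
extent = [df.columns.min() - dx, df.columns.max() + dx,
          df.index.min() - dy, df.index.max() + dy]
im = ax.imshow(df.to_numpy(), extent=extent, origin="lower", aspect="auto")
fig.colorbar(im, ax=ax)  # colour scale, as sns.heatmap draws by default

# Mimic annot=True: write each value at its bin centre in data coordinates.
for yv in df.index:
    for xv in df.columns:
        ax.text(xv, yv, f"{df.at[yv, xv]:.2f}",
                ha="center", va="center", color="w")

plt.show()
```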

Upvotes: 1
