Liam Sfee

Reputation: 63

How do I add extra space between the columns of a matplotlib heatmap?

[Image: heatmap with linewidth=18]

I am trying to add extra space between the columns of the heatmap. Right now I have set the heatmap's linewidths to 18, which leaves even gaps between the tiles, but I'd like the gap between the two columns to be wider than the horizontal gaps between the rows.

I looked into the matplotlib.collections source code and found:

def set_linewidth(self, lw):
    """
    Set the linewidth(s) for the collection.  *lw* can be a scalar
    or a sequence; if it is a sequence the patches will cycle
    through the sequence

    Parameters
    ----------
    lw : float or sequence of floats
    """
    if lw is None:
        lw = mpl.rcParams['patch.linewidth']
        if lw is None:
            lw = mpl.rcParams['lines.linewidth']
    # get the un-scaled/broadcast lw
    self._us_lw = np.atleast_1d(np.asarray(lw))

    # scale all of the dash patterns.
    self._linewidths, self._linestyles = self._bcast_lwls(
        self._us_lw, self._us_linestyles)
    self.stale = True

This function doesn't seem to get me anywhere, though. I tried passing a list of widths instead of a single value, but only the first element of the list is used.
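A minimal sketch to check what happens to a sequence of widths (it assumes a recent matplotlib and uses the non-interactive Agg backend; `data` is random placeholder data). The `QuadMesh` returned by `pcolormesh` does keep the whole sequence, but as far as I can tell `QuadMesh.draw` hands only the first width to the renderer, which would match the behaviour described above. And even in collections that do cycle through the sequence, each width applies to a whole cell outline, so a list of widths cannot make the vertical gaps wider than the horizontal ones.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

data = np.random.rand(5, 2)  # placeholder data
fig, ax = plt.subplots()

# Pass a sequence of widths instead of a scalar.
mesh = ax.pcolormesh(data, cmap="YlGnBu", edgecolors="white",
                     linewidths=[18, 60], vmin=0.0, vmax=1.0)

# The collection stores the full sequence; whether all of it is drawn
# depends on the draw method (assumption: QuadMesh uses only the first).
widths = [float(w) for w in mesh.get_linewidth()]
print(widths)
```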

EDIT: Many thanks to JohanC for the solution. For those who asked, this is the code I had that produced the heatmap in the picture:

self.axes[i].pcolormesh(data, cmap="YlGnBu", edgecolor=BG_COLOUR,
                        linewidths=18, vmin=0.0, vmax=1.0)

And this is how I applied JohanC's solution:

self.axes[sel].pcolormesh(data, cmap="YlGnBu", vmin=0.0, vmax=1.0)
for j in range(data.shape[0] + 1):
    self.axes[sel].axhline(j, color=BG_COLOUR, lw=20)
for j in range(data.shape[1] + 1):
    self.axes[sel].axvline(j, color=BG_COLOUR, lw=60)
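For anyone who wants to run this outside a class, here is a self-contained sketch of the same idea: `self.axes[sel]` becomes a plain `ax`, and `BG_COLOUR` is a hypothetical stand-in set to white. Since `lw` is given in points, the gaps keep the same absolute size when the figure is resized, while the cells themselves scale.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in an interactive session
import matplotlib.pyplot as plt
import numpy as np

BG_COLOUR = "white"  # hypothetical placeholder for the background colour constant

data = np.random.rand(5, 2)  # placeholder data
fig, ax = plt.subplots()

# Draw the mesh without edges; the gaps come from the lines drawn on top.
ax.pcolormesh(data, cmap="YlGnBu", vmin=0.0, vmax=1.0)

# pcolormesh(data) puts cell boundaries at integer coordinates 0..n,
# so one line per boundary covers every row and column edge.
for j in range(data.shape[0] + 1):
    ax.axhline(j, color=BG_COLOUR, lw=20)  # gaps between rows
for j in range(data.shape[1] + 1):
    ax.axvline(j, color=BG_COLOUR, lw=60)  # wider gaps between columns
```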


Upvotes: 2

Views: 2588

Answers (1)

JohanC

Reputation: 80459

You could explicitly draw horizontal and vertical lines with different widths:

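import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

data = np.random.rand(5, 2)
ax = sns.heatmap(data)
# sns.heatmap puts the cell borders at integer coordinates 0..n,
# so draw one line per border on top of the cells.
for i in range(data.shape[0] + 1):
    ax.axhline(i, color='white', lw=20)  # gaps between rows
for i in range(data.shape[1] + 1):
    ax.axvline(i, color='white', lw=60)  # wider gaps between columns
plt.show()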
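A different technique, sketched here as an alternative rather than as part of the answer above: draw each cell as a `Rectangle` in a `PatchCollection`, shrunk by separate amounts in x and y. The gaps are then in data units (fractions of a cell) and scale with the figure, unlike `axhline`/`axvline` widths, which are in points. The names `gapped_heatmap`, `x_gap` and `y_gap` are hypothetical, and the sketch does not reproduce seaborn's tick labels or colorbar.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle


def gapped_heatmap(ax, data, x_gap=0.3, y_gap=0.1, cmap="YlGnBu", vmin=0.0, vmax=1.0):
    """Hypothetical helper: one rectangle per cell, shrunk by x_gap
    horizontally and y_gap vertically (both in fractions of a cell)."""
    n_rows, n_cols = data.shape
    patches = [
        Rectangle((col + x_gap / 2, row + y_gap / 2), 1 - x_gap, 1 - y_gap)
        for row in range(n_rows)
        for col in range(n_cols)
    ]
    coll = PatchCollection(patches, cmap=cmap)
    coll.set_array(data.ravel())  # row-major order matches the patch order above
    coll.set_clim(vmin, vmax)
    ax.add_collection(coll)
    ax.set_xlim(0, n_cols)
    ax.set_ylim(n_rows, 0)  # row 0 at the top, as in sns.heatmap
    return coll


data = np.random.rand(5, 2)  # placeholder data
fig, ax = plt.subplots()
coll = gapped_heatmap(ax, data, x_gap=0.3, y_gap=0.1)
```

Because the gaps are fractions of a cell, `x_gap=0.3` leaves 30% of each cell's width empty regardless of figure size; pick `x_gap` larger than `y_gap` to get wider column gaps.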

[Image: example plot]

Upvotes: 3
