RDoc

Reputation: 346

How to add a colorcode to the yticklabels on a seaborn heatmap?

I have a matrix of gene expressions vs. cells that I would like to display as a heatmap, which in itself isn't an issue. However, displaying all of the gene names as yticklabels would be far too cluttered. I have therefore annotated each gene as belonging to a particular functional group, and I would like to represent each functional group by a color and show those colors along the heatmap's y axis, in the same order as the genes appear. To clarify, I do not want to group the rows by color, which I believe is what seaborn's clustermap would do.

So far I have a pandas DataFrame whose rows carry a MultiIndex of genes and their respective functional groups, and whose columns are cells.

I've searched extensively on Stack Overflow and Google for answers, without any luck. This is my first attempt at anything of the kind, so unfortunately I do not know where to start.

For simplicity, let's say you have the following dataframe:

import seaborn as sns
import numpy as np
import pandas as pd

data=pd.DataFrame(np.array([(0,1,2),(4,5,6),(7,8,9)]), columns=['C1','C2','C3'], index=pd.MultiIndex.from_arrays([['Gene1','Gene2','Gene3'],['A','B','A']]))

That would yield the following:

           C1  C2  C3
Gene1   A   0   1   2
Gene2   B   4   5   6
Gene3   A   7   8   9

Now, I can simply call sns.heatmap(data) to generate the heatmap. However, how can I customize it so that the y axis shows colors representing A and B rather than Gene1, Gene2, Gene3 as yticklabels? For instance, if A is blue and B is green, I want the yticklabels (from top to bottom) to show blue, green, blue.

Thanks a lot in advance.

Upvotes: 0

Views: 152

Answers (1)

ImportanceOfBeingErnest

Reputation: 339290

Here is a possible solution: create a new axes to the left of the heatmap that shows a second, one-column heatmap based on the values of the second MultiIndex level.

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd

data=pd.DataFrame(np.array([(0,1,2),(4,5,6),(7,8,9)]), 
                  columns=['C1','C2','C3'], 
                  index=pd.MultiIndex.from_arrays([['Gene1','Gene2','Gene3'],['A','B','A']]))


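# One-column frame of group labels (second index level), indexed by gene name
cats = data.index.to_frame().set_index(0)
# u: sorted unique groups; inv: integer code of each row's group (kept 1-D)
u, inv = np.unique(cats.values.ravel(), return_inverse=True)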

colors = ["navy", "limegreen", "gold"]
assert(len(u) <= len(colors))

cmap = mcolors.ListedColormap(colors)
norm = mcolors.BoundaryNorm(np.arange(len(u)+1)-.5, len(u))

fig, (sax, hax) = plt.subplots(ncols=2, sharey=True,
                               gridspec_kw=dict(width_ratios=[1, data.shape[1]]))

im = sax.imshow(np.atleast_2d(inv).T, cmap=cmap, norm=norm)
hax.imshow(data.values, cmap="Greys")

sax.set_yticks(np.arange(len(cats)))
sax.set_yticklabels(cats.index)
sax.tick_params(bottom=False, labelbottom=False)

hax.set_xticks(np.arange(len(data.columns)))
hax.set_xticklabels(data.columns)

cbar = fig.colorbar(im, cax = fig.add_axes([.125, .08, .1, .04]), 
                    orientation="horizontal", ticks=np.arange(len(u)))
cbar.set_ticklabels(u)

plt.show()
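
If the goal is literally to color the tick labels rather than to add a color strip, a minimal sketch along the following lines might work. It is an alternative to the approach above, not part of it: the `group_colors` mapping is a hypothetical one taken from the blue/green example in the question, and the plot uses plain matplotlib `imshow` in place of `sns.heatmap`. The same loop over `get_yticklabels()` should in principle apply to the Axes returned by `sns.heatmap`, although that is not shown here.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

data = pd.DataFrame(
    np.array([(0, 1, 2), (4, 5, 6), (7, 8, 9)]),
    columns=["C1", "C2", "C3"],
    index=pd.MultiIndex.from_arrays([["Gene1", "Gene2", "Gene3"], ["A", "B", "A"]]),
)

# Hypothetical group-to-color mapping, following the example in the question.
group_colors = {"A": "blue", "B": "green"}

genes = data.index.get_level_values(0)   # first index level: gene names
groups = data.index.get_level_values(1)  # second index level: functional groups

fig, ax = plt.subplots()
ax.imshow(data.values, cmap="Greys", aspect="auto")
ax.set_xticks(np.arange(data.shape[1]))
ax.set_xticklabels(data.columns)
ax.set_yticks(np.arange(data.shape[0]))
ax.set_yticklabels(genes)

# Color each y tick label according to the group of the row it labels.
label_colors = [group_colors[g] for g in groups]
for label, color in zip(ax.get_yticklabels(), label_colors):
    label.set_color(color)

# Render once; use plt.show() or fig.savefig(...) to actually view the result.
fig.canvas.draw()
```

This keeps the gene names visible, so with many genes the color strip from the answer above is probably the more readable option.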


Upvotes: 1
