Petar

Reputation: 365

Plotly clustered heatmap (with dendrogram)/Python

I am trying to create a clustered heatmap (with a dendrogram) using Plotly in Python. The example on the Plotly website does not scale well, and most of the solutions I have found are in R or JavaScript. I want a heatmap with a dendrogram on the left side only, showing the clusters along the y axis (from hierarchical clustering). A really good-looking example is this one: https://chart-studio.plotly.com/~jackp/6748. My goal is something like that, but with only the left-side dendrogram. If someone can implement this in Python, I will be really grateful!

Let the data be X = np.random.randint(0, 10, size=(120, 10))

Upvotes: 4

Views: 6450

Answers (3)

Andrew R

Reputation: 1

You can also use seaborn's clustermap: https://seaborn.pydata.org/generated/seaborn.clustermap.html
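A minimal sketch of what that could look like for the data in the question, assuming seaborn and matplotlib are installed. Note that this produces a static matplotlib figure rather than a Plotly one. The Agg backend, colormap, and figure size are illustrative choices, not part of the original suggestion.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display

import numpy as np
import seaborn as sns

X = np.random.randint(0, 10, size=(120, 10))

# Cluster the rows only: the row dendrogram is drawn on the left side,
# and no column dendrogram is computed.
grid = sns.clustermap(X, row_cluster=True, col_cluster=False,
                      cmap="Blues", figsize=(7, 8))

# Row order chosen by the hierarchical clustering (indices into X)
row_order = grid.dendrogram_row.reordered_ind
```

Calling grid.savefig("clustermap.png") afterwards should write the figure to disk.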

Upvotes: 0

Madcat

Reputation: 379

  1. The simplest solution to this problem is to use the dash_bio.Clustergram function from the dash_bio package.
import numpy as np
import dash_bio as dashbio

X = np.random.randint(0, 10, size=(120, 10))

dashbio.Clustergram(
    data=X,
    # row_labels=rows,
    # column_labels=columns,
    cluster='row',
    color_threshold={
        'row': 250,
        'col': 700
    },
    height=800,
    width=700,
    color_map= [
        [0.0, '#636EFA'],
        [0.25, '#AB63FA'],
        [0.5, '#FFFFFF'],
        [0.75, '#E763FA'],
        [1.0, '#EF553B']
    ]
)

  2. A more laborious solution is to combine plotly.figure_factory.create_dendrogram with plotly.graph_objects.Heatmap, as in the Plotly documentation. The example there is a heatmap of pairwise distances rather than a dendrogram heatmap of the data itself, but you can use the same two functions to build a dendrogram heatmap.

Upvotes: 1

vestland

Reputation: 61104

The following suggestion draws on elements from both Dendrograms in Python and chart-studio.plotly.com/~jackp. This particular plot uses your data, X = np.random.randint(0, 10, size=(120, 10)). One thing the linked approaches had in common was, in my opinion, that the datasets and data munging procedures were a bit messy. So I decided to build the following figure on a pandas dataframe, df = pd.DataFrame(X), to hopefully make everything a bit clearer.

Complete code

import plotly.graph_objects as go
import plotly.figure_factory as ff

import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform

X = np.random.randint(0, 10, size=(120, 10))
df = pd.DataFrame(X)

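# Initialize figure by creating the upper dendrogram. Its traces are hidden
# (only the side dendrogram is wanted), but its x tick positions are reused for the heatmap.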
fig = ff.create_dendrogram(df.values, orientation='bottom')
fig.for_each_trace(lambda trace: trace.update(visible=False))

for i in range(len(fig['data'])):
    fig['data'][i]['yaxis'] = 'y2'

# Create Side Dendrogram
dendro_side = ff.create_dendrogram(X, orientation='right')
for i in range(len(dendro_side['data'])):
    dendro_side['data'][i]['xaxis'] = 'x2'

# Add Side Dendrogram Data to Figure
for data in dendro_side['data']:
    fig.add_trace(data)

# Create Heatmap of the pairwise distances between rows, ordered by the dendrogram leaves
dendro_leaves = dendro_side['layout']['yaxis']['ticktext']
dendro_leaves = list(map(int, dendro_leaves))
data_dist = pdist(df.values)
heat_data = squareform(data_dist)
heat_data = heat_data[dendro_leaves,:]
heat_data = heat_data[:,dendro_leaves]

heatmap = [
    go.Heatmap(
        x = dendro_leaves,
        y = dendro_leaves,
        z = heat_data,
        colorscale = 'Blues'
    )
]

heatmap[0]['x'] = fig['layout']['xaxis']['tickvals']
heatmap[0]['y'] = dendro_side['layout']['yaxis']['tickvals']

# Add Heatmap Data to Figure
for data in heatmap:
    fig.add_trace(data)

# Edit Layout
fig.update_layout({'width':800, 'height':800,
                         'showlegend':False, 'hovermode': 'closest',
                         })
# Edit xaxis
fig.update_layout(xaxis={'domain': [.15, 1],
                                  'mirror': False,
                                  'showgrid': False,
                                  'showline': False,
                                  'zeroline': False,
                                  'ticks':""})
# Edit xaxis2
fig.update_layout(xaxis2={'domain': [0, .15],
                                   'mirror': False,
                                   'showgrid': False,
                                   'showline': False,
                                   'zeroline': False,
                                   'showticklabels': False,
                                   'ticks':""})

# Edit yaxis
fig.update_layout(yaxis={'domain': [0, 1],
                                  'mirror': False,
                                  'showgrid': False,
                                  'showline': False,
                                  'zeroline': False,
                                  'showticklabels': False,
                                  'ticks': ""
                        })
# Edit yaxis2
fig.update_layout(yaxis2={'domain':[.825, .975],
                                   'mirror': False,
                                   'showgrid': False,
                                   'showline': False,
                                   'zeroline': False,
                                   'showticklabels': False,
                                   'ticks':""})

fig.update_layout(paper_bgcolor="rgba(0,0,0,0)",
                  plot_bgcolor="rgba(0,0,0,0)",
                  xaxis_tickfont = dict(color = 'rgba(0,0,0,0)'))

fig.show()

Upvotes: 4
