Reputation: 1949
Using another answer, I'm wondering whether it is possible to add 3 or more legends. Adapting the code from the author, I could add 4 row labels, but adding the legends is tricky. If I add more `row_dendrogram` and `col_dendrogram` legends, they simply do not show independently from the others.
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.pyplot import gcf

# header=[0, 1, 2] gives the columns a three-level index; the "network" and
# "node" levels are used below to build the colour annotations.
networks = sns.load_dataset("brain_networks", index_col=0, header=[0, 1, 2])

# Label 1: one colour per unique value of the "network" level.
network_labels = networks.columns.get_level_values("network")
network_pal = sns.cubehelix_palette(
    network_labels.unique().size, light=.9, dark=.1, reverse=True, start=1, rot=-2
)
network_lut = dict(zip(map(str, network_labels.unique()), network_pal))
network_colors = pd.Series(network_labels, index=networks.columns).map(network_lut)

# Label 2: one colour per unique value of the "node" level.
node_labels = networks.columns.get_level_values("node")
node_pal = sns.cubehelix_palette(node_labels.unique().size)
node_lut = dict(zip(map(str, node_labels.unique()), node_pal))
node_colors = pd.Series(node_labels, index=networks.columns).map(node_lut)

# Label 3: reuses the "node" level with an "hls" palette. The explicit
# name gives this series its own column in the joined frame below.
lab3_labels = networks.columns.get_level_values("node")
lab3_pal = sns.color_palette("hls", lab3_labels.unique().size)
lab3_lut = dict(zip(map(str, lab3_labels.unique()), lab3_pal))
lab3_colors = pd.Series(lab3_labels, index=networks.columns, name='lab3').map(lab3_lut)

# Label 4: reuses the "node" level with a "husl" palette.
lab4_labels = networks.columns.get_level_values("node")
lab4_pal = sns.color_palette("husl", lab4_labels.unique().size)
lab4_lut = dict(zip(map(str, lab4_labels.unique()), lab4_pal))
lab4_colors = pd.Series(lab4_labels, index=networks.columns, name='lab4').map(lab4_lut)

# Combine the four colour series into one frame (one column per label set).
network_node_colors = (
    pd.DataFrame(network_colors)
    .join(pd.DataFrame(node_colors))
    .join(pd.DataFrame(lab3_colors))
    .join(pd.DataFrame(lab4_colors))
)

g = sns.clustermap(
    networks.corr(),
    row_cluster=False, col_cluster=False,
    row_colors=network_node_colors,
    col_colors=network_node_colors,
    linewidths=0,
    xticklabels=False, yticklabels=False,
    center=0, cmap="vlag",
)

# add legends
# Zero-size bars are drawn only so that each label gets a legend handle.
for label in network_labels.unique():
    g.ax_col_dendrogram.bar(0, 0, color=network_lut[label], label=label, linewidth=0)
l1 = g.ax_col_dendrogram.legend(
    title='Network', loc="center", ncol=5,
    bbox_to_anchor=(0.47, 0.89), bbox_transform=gcf().transFigure,
)

for label in node_labels.unique():
    g.ax_row_dendrogram.bar(0, 0, color=node_lut[label], label=label, linewidth=0)
l2 = g.ax_row_dendrogram.legend(
    title='Node', loc="center", ncol=2,
    bbox_to_anchor=(0.8, 0.89), bbox_transform=gcf().transFigure,
)

# how to add other row dendrograms here without them overlapping with the existing ones?
plt.show()
Upvotes: 4
Views: 3327
Reputation: 2508
I believe the problem here is that one cannot directly access the axes of the plot. The legend is based on the bar graph, i.e. the row which you add. I have found the following workaround which, to be honest, is not nice, but it works. It follows the classic matplotlib problem of adding an artist to an `ax`; you can read more about it in the following posts: 1, 2, 3 and in the docs.
So what I do is save the bar-plot objects when I create them and later build the legend from them. The full code is below. That said, I would recommend contacting the author and raising a question/issue there.
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.pyplot import gcf

# header=[0, 1, 2] gives the columns a three-level index; the "network" and
# "node" levels are used below to build the colour annotations.
networks = sns.load_dataset("brain_networks", index_col=0, header=[0, 1, 2])

# Label 1: one colour per unique value of the "network" level.
network_labels = networks.columns.get_level_values("network")
network_pal = sns.cubehelix_palette(
    network_labels.unique().size, light=.9, dark=.1, reverse=True, start=1, rot=-2
)
network_lut = dict(zip(map(str, network_labels.unique()), network_pal))
network_colors = pd.Series(network_labels, index=networks.columns).map(network_lut)

# Label 2: one colour per unique value of the "node" level.
node_labels = networks.columns.get_level_values("node")
node_pal = sns.cubehelix_palette(node_labels.unique().size)
node_lut = dict(zip(map(str, node_labels.unique()), node_pal))
node_colors = pd.Series(node_labels, index=networks.columns).map(node_lut)

# Label 3: reuses the "node" level with an "hls" palette. The explicit
# name gives this series its own column in the joined frame below.
lab3_labels = networks.columns.get_level_values("node")
lab3_pal = sns.color_palette("hls", lab3_labels.unique().size)
lab3_lut = dict(zip(map(str, lab3_labels.unique()), lab3_pal))
lab3_colors = pd.Series(lab3_labels, index=networks.columns, name='lab3').map(lab3_lut)

# Label 4: reuses the "node" level with a "husl" palette.
lab4_labels = networks.columns.get_level_values("node")
lab4_pal = sns.color_palette("husl", lab4_labels.unique().size)
lab4_lut = dict(zip(map(str, lab4_labels.unique()), lab4_pal))
lab4_colors = pd.Series(lab4_labels, index=networks.columns, name='lab4').map(lab4_lut)

# Combine the four colour series into one frame (one column per label set).
network_node_colors = (
    pd.DataFrame(network_colors)
    .join(pd.DataFrame(node_colors))
    .join(pd.DataFrame(lab3_colors))
    .join(pd.DataFrame(lab4_colors))
)

g = sns.clustermap(
    networks.corr(),
    row_cluster=False, col_cluster=False,
    row_colors=network_node_colors,
    col_colors=network_node_colors,
    linewidths=0,
    xticklabels=False, yticklabels=False,
    center=0, cmap="vlag",
)

# add legends
# Zero-size bars are drawn only so that each label gets a legend handle.
for label in network_labels.unique():
    g.ax_col_dendrogram.bar(0, 0, color=network_lut[label], label=label, linewidth=0)
l1 = g.ax_col_dendrogram.legend(
    title='Network', loc="center", ncol=5,
    bbox_to_anchor=(0.35, 0.89), bbox_transform=gcf().transFigure,
)

for label in node_labels.unique():
    g.ax_row_dendrogram.bar(0, 0, color=node_lut[label], label=label, linewidth=0)
l2 = g.ax_row_dendrogram.legend(
    title='Node', loc="center", ncol=2,
    bbox_to_anchor=(0.66, 0.89), bbox_transform=gcf().transFigure,
)

# Keep the bar containers so the extra legends can be built from explicit
# handles instead of from whatever the axes currently holds.
lab3_handles = [
    g.ax_row_dendrogram.bar(0, 0, color=lab3_lut[label], label=label, linewidth=0)
    for label in lab3_labels.unique()
]
# add the legend
legend3 = plt.legend(
    lab3_handles, lab3_labels.unique(), loc="center", title='lab3',
    bbox_to_anchor=(.78, 0.89), bbox_transform=gcf().transFigure,
)

lab4_handles = [
    g.ax_row_dendrogram.bar(0, 0, color=lab4_lut[label], label=label, linewidth=0)
    for label in lab4_labels.unique()
]
# add the second legend
legend4 = plt.legend(
    lab4_handles, lab4_labels.unique(), loc="center", title='lab4', ncol=2,
    bbox_to_anchor=(.9, 0.89), bbox_transform=gcf().transFigure,
)
# The second plt.legend call replaces legend3 on the current axes, so
# re-attach legend3 as a plain artist to display both.
plt.gca().add_artist(legend3)

plt.show()
Upvotes: 4