Pat

Reputation: 1339

matplotlib seaborn long rownames affect other subplots' axes

I want to plot a table and I really like seaborn, so I wrote a simple function that draws a seaborn heatmap and fills the cells with a grey color. My problem is that when I use that table in a subplot, the other plots along the vertical axis get shrunk to the width of the axes that I plot the seaborn heatmap on.

import seaborn as sns
from matplotlib.colors import ListedColormap
import numpy as np
import pandas as pd
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

def plotTable(tableDf, ax=None):
    # Dummy all-zero frame: the heatmap only provides the uniformly colored grid.
    s = len(tableDf), len(tableDf.columns)
    dataZeros = np.zeros(s)
    dataTableFake = pd.DataFrame(dataZeros, index=tableDf.index, columns=tableDf.columns)
    # The real values are drawn as string annotations on top of the grey cells.
    dataTable = tableDf.values.astype(str)

    col = ["#dbdada"]
    my_cmap = ListedColormap(col)
    ax = sns.heatmap(dataTableFake, cmap=my_cmap, annot_kws={"size": 8}, cbar=False, linewidths=0.5,
                     square=False, ax=ax, annot=dataTable, fmt='')
    ax.xaxis.tick_top()
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    return ax


table = pd.DataFrame(dict(value01=1.01,
                          value02=2.02,
                          value03=3.03),
                     index=["This is a pretty long index and I don't want it to shrink the other subplots"])

gs = gridspec.GridSpec(2, 1)

ax1 = plt.subplot(gs[0, 0])
ax2 = plt.subplot(gs[1, 0])

plotTable(tableDf=table, ax=ax1)
plt.tight_layout()

This creates the following plot:

enter image description here

I don't want the second axis (i.e. ax2) to be affected by the length of the rownames in ax1. I hope this makes clear what I mean:

enter image description here

Upvotes: 1

Views: 317

Answers (1)

ImportanceOfBeingErnest

Reputation: 339745

You can define a gridspec with 2 rows and 3 columns and place the "table" in the rightmost subplot of the first row. The other plot can then span all 3 columns of the second row.

gs = gridspec.GridSpec(2, 3)
ax1 = plt.subplot(gs[0, 2])
ax2 = plt.subplot(gs[1, :])

Finally, don't call tight_layout here: it tries to make room for the long row label by shrinking the axes, which messes everything up. If you need to adjust the spacings, set them manually using plt.subplots_adjust(...)
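For reference, here is a minimal self-contained sketch of that layout. To keep it free of the seaborn and pandas dependencies, the heatmap is replaced by a plain axes with one long y tick label. That stand-in, the Agg backend, the figure size and the spacing values passed to subplots_adjust are assumptions for illustration, not part of the original code.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

fig = plt.figure(figsize=(8, 4))  # illustrative size
gs = gridspec.GridSpec(2, 3, figure=fig)

# "Table" axes: rightmost cell of the first row only.
ax1 = fig.add_subplot(gs[0, 2])
# Second plot: spans all three columns of the second row.
ax2 = fig.add_subplot(gs[1, :])

# Stand-in for the heatmap: a single long row label on the y axis.
long_label = "This is a pretty long index and I don't want it to shrink the other subplots"
ax1.set_yticks([0.5])
ax1.set_yticklabels([long_label])
ax1.set_xticks([])

ax2.plot([0, 1, 2], [0, 1, 0])

# Manual spacing instead of tight_layout; the values are illustrative guesses.
fig.subplots_adjust(left=0.08, right=0.97, hspace=0.4, wspace=0.1)

pos1 = ax1.get_position()  # position of the table axes in figure coordinates
pos2 = ax2.get_position()  # position of the full-width axes
```

The long label extends to the left of ax1 into the space of the two unused grid cells, so ax2 keeps the full figure width. With the real data, plotTable(tableDf=table, ax=ax1) would take the place of the stand-in lines.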


Upvotes: 1
