zwer

Reputation: 25829

Matplotlib - Dynamic (bar) chart height based on data?

After struggling with matplotlib for longer than I'd like to admit, trying to do something that's a breeze in pretty much any other plotting library I've ever used, I've decided to ask the Stackiverse for some insight. In a nutshell, I need to create multiple horizontal bar charts that all share the x axis and have a different number of values on the y axis. All the bars should have the same height, while the charts themselves adjust to the number of entries. A simplified version of the data structure I need to plot looks something like this:

[
    {"name": "Category 1", "entries": [
        {"name": "Entry 1", "value": 5},
        {"name": "Entry 2", "value": 2},
    ]},
    {"name": "Category 2", "entries": [
        {"name": "Entry 1", "value": 1},
    ]},
    {"name": "Category 3", "entries": [
        {"name": "Entry 1", "value": 1},
        {"name": "Entry 2", "value": 10},
        {"name": "Entry 3", "value": 4},
    ]},    
]

And the closest I got to the result I'd like is this:

import matplotlib.pyplot as plt

def plot_data(data):
    total_categories = len(data)  # holds how many charts to create
    max_values = 1  # holds the maximum number of bars to create
    for category in data:
        max_values = max(max_values, len(category["entries"]))
    fig = plt.figure(1)
    ax = None
    for index, category in enumerate(data):
        entries = []
        values = []
        for entry in category["entries"]:
            entries.append(entry["name"])
            values.append(entry["value"])
        if not entries:
            continue  # do not create empty charts
        y_ticks = range(1, len(entries) + 1)
        ax = fig.add_subplot(total_categories, 1, index + 1, sharex=ax)
        ax.barh(y_ticks, values)
        ax.set_ylim(0, max_values + 1)  # limit the y axis for fixed height
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(entries)
        ax.invert_yaxis()
        ax.set_title(category["name"], loc="left")
    fig.tight_layout()

This keeps the bar height the same (at least within the figure) no matter how many entries a category has, thanks to the y limit (set_ylim()) being set to the highest number of bars across the data. However, it also leaves nasty gaps in categories that have fewer than the maximum number of entries. To put it in visual terms, I'm trying to get from Actual to Expected:

[Image: Actual vs. Expected]
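Why the set_ylim() trick equalises bars but leaves gaps follows from a bit of arithmetic: a bar's physical height is its height in data units (barh defaults to height=0.8) times the axes height in inches, divided by the span of the y limits. The helper below is a hypothetical sketch of that relationship; it ignores tick labels, titles and figure margins.

```python
def bar_height_inches(axes_height_in, ylim_span, bar_height=0.8):
    """Physical height of one horizontal bar (hypothetical helper).

    axes_height_in: height of the axes box in inches
    ylim_span:      ymax - ymin of the axes, in data units
    bar_height:     bar height in data units (0.8 is the barh default)
    """
    return bar_height * axes_height_in / ylim_span


# Equal axes heights + a common y span (max_values + 1 = 4): the bars match,
# but a category with one entry still reserves room for 4 slots (the "gaps").
equal_axes = [bar_height_inches(1.5, 3 + 1) for _ in range(3)]

# Axes heights proportional to each category's own span (n + 1 slots):
# the bars still match, and no slots are wasted.
inch_per_slot = 0.2
counts = [2, 1, 3]
proportional = [bar_height_inches((n + 1) * inch_per_slot, n + 1) for n in counts]
```

So the bar size stays fixed whenever each axes' height in inches is proportional to its own y span, which is the condition the answer below enforces by positioning the axes manually.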

I've tried removing the gaps with gridspec and with different scales depending on the number of entries, but that only made the result look even more 'skewed' and inconsistent. I also tried generating multiple charts, manipulating the figure size and stitching them together in post-processing, but I couldn't find a way to reliably keep the bar height the same regardless of the data. I'm certain the metrics needed for precise scaling can be extracted from some obscure object in matplotlib, but at this point I'm afraid I'll go on another wild-goose chase if I try to trace the whole drawing procedure.

Furthermore, if individual subplots can be collapsed around the data, how can I make the figure grow with the data? For example, if I add a fourth category to the above data, instead of the figure 'growing' in height by another chart, all the charts shrink to fit on the default figure size. I think I understand the logic behind matplotlib with axes units and all that, and I know I can set the figure size to increase the overall height, but I have no idea how to keep it consistent across the charts, namely how to keep the bar height exactly the same regardless of the data.
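The shrinking is easy to quantify: add_subplot(total_categories, 1, i) splits roughly the same figure height among more rows, so each axes, and every bar in it, gets smaller. The rough sketch below assumes the figure height is split evenly between rows and ignores margins, titles and inter-plot spacing; 4.8 inches is matplotlib's default figure height, and the helper name is hypothetical.

```python
def approx_bar_height(fig_height_in, n_rows, max_entries, bar_height=0.8):
    """Rough bar height in inches for n_rows equal subplots (hypothetical).

    Each row gets fig_height_in / n_rows inches, and its y limits span
    max_entries + 1 data units, as with set_ylim(0, max_values + 1).
    """
    axes_height = fig_height_in / n_rows
    return bar_height * axes_height / (max_entries + 1)


three = approx_bar_height(4.8, 3, 3)  # three categories
four = approx_bar_height(4.8, 4, 3)   # a fourth category is added
```

With these numbers the bars drop from about 0.32 to about 0.24 inches. Making the figure grow instead means fixing the size of a bar in inches and deriving the figure height from the data.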

Do I really need to plot everything manually to get what I want? If so, I might just dump the whole matplotlib package and create my own SVGs from scratch. With hindsight, given the amount of time I've spent on this, I probably should've done that in the first place but now I'm way too stubborn to give it up (or I am a victim of the dreaded sunk cost fallacy).

Any ideas?

Thanks

Upvotes: 6

Views: 6234

Answers (1)

ImportanceOfBeingErnest

Reputation: 339775

I think the only way to have equal bar widths (width in the vertical direction) and differing subplot sizes at the same time is to position the axes in the figure manually.

To this end, you can specify how large in inches a bar should be and how much spacing you want between the subplots, in units of this bar width. From those numbers, together with the amount of data to plot, you can calculate the total figure height in inches. Each subplot is then positioned (via fig.add_axes) according to the amount of data and the spacing in the previous subplots. This fills the figure without gaps, and adding a new set of data makes the figure taller instead of shrinking the existing charts.
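The positioning arithmetic can be isolated into a pure function, which makes it easier to check. This sketch mirrors the formulas used in the code below but needs neither numpy nor matplotlib; the name layout is hypothetical, and each returned rectangle is meant to be passed to fig.add_axes.

```python
def layout(counts, barwidth=0.2, spacing=3, figx=5, left=4, right=2):
    """Figure height and axes rectangles for equal bar heights (sketch).

    counts:     number of bars in each category
    barwidth:   inches reserved per bar slot
    spacing:    gap above, between and below subplots, in units of barwidth
    left/right: horizontal margins, in units of barwidth
    Returns (figy, rects); each rect is [left, bottom, width, height]
    in figure-fraction coordinates.
    """
    tc = len(counts)
    # every subplot spans n + 1 slots, plus tc + 1 gaps of `spacing` slots
    figy = (sum(counts) + tc + (tc + 1) * spacing) * barwidth
    rects = []
    used = spacing  # slots consumed above the current axes (starts with top gap)
    for n in counts:
        height = (n + 1) * barwidth / figy
        top = 1 - used * barwidth / figy
        rects.append([left * barwidth / figx,
                      top - height,
                      1 - (left + right) * barwidth / figx,
                      height])
        used += n + 1 + spacing
    return figy, rects
```

For the four categories in the data below, layout([2, 1, 3, 1]) gives a figure height of 5.2 inches, and every axes is exactly (n + 1) * barwidth inches tall, so each bar slot gets the same 0.2 inches.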

data = [
    {"name": "Category 1", "entries": [
        {"name": "Entry 1", "value": 5},
        {"name": "Entry 2", "value": 2},
    ]},
    {"name": "Category 2", "entries": [
        {"name": "Entry 1", "value": 1},
    ]},
    {"name": "Category 3", "entries": [
        {"name": "Entry 1", "value": 1},
        {"name": "Entry 2", "value": 10},
        {"name": "Entry 3", "value": 4},
    ]}, 
    {"name": "Category 4", "entries": [
        {"name": "Entry 1", "value": 6},
    ]},
]

import matplotlib.pyplot as plt
import numpy as np

def plot_data(data,
              barwidth = 0.2, # inch per bar
              spacing = 3,    # spacing between subplots in units of barwidth
              figx = 5,       # figure width in inch
              left = 4,       # left margin in units of bar width
              right=2):       # right margin in units of bar width

    tc = len(data) # "total_categories", holds how many charts to create
    max_values = []  # holds the number of bars in each category
    for category in data:
        max_values.append( len(category["entries"]))
    max_values = np.array(max_values)

    # total figure height:
    figy = ((np.sum(max_values)+tc) + (tc+1)*spacing)*barwidth #inch

    fig = plt.figure(figsize=(figx,figy))
    ax = None
    for index, category in enumerate(data):
        entries = []
        values = []
        for entry in category["entries"]:
            entries.append(entry["name"])
            values.append(entry["value"])
        if not entries:
            continue  # do not create empty charts
        y_ticks = range(1, len(entries) + 1)
        # coordinates of new axes [left, bottom, width, height]
        coord = [left*barwidth/figx, 
                 1-barwidth*((index+1)*spacing+np.sum(max_values[:index+1])+index+1)/figy,  
                 1-(left+right)*barwidth/figx,  
                 (max_values[index]+1)*barwidth/figy ] 

        ax = fig.add_axes(coord, sharex=ax)
        ax.barh(y_ticks, values)
        ax.set_ylim(0, max_values[index] + 1)  # limit the y axis for fixed height
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(entries)
        ax.invert_yaxis()
        ax.set_title(category["name"], loc="left")


plot_data(data)
plt.savefig(__file__+".png")
plt.show()
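A possibly simpler variant (not part of the original answer) lets a GridSpec do the positioning: each row gets a height_ratios entry of n + 1, and the spacing is converted into GridSpec's hspace, which matplotlib measures as a fraction of the average axes height. The hypothetical helper below only computes the keyword arguments; passing them via fig.add_gridspec assumes matplotlib 3.0 or newer.

```python
def gridspec_kwargs(counts, barwidth=0.2, spacing=3, figx=5, left=4, right=2):
    """GridSpec parameters giving every bar the same height (sketch).

    Uses the same units as plot_data: barwidth in inches per slot,
    spacing/left/right in units of barwidth. Returns (figsize, kwargs).
    """
    tc = len(counts)
    slots = sum(counts) + tc                # each subplot spans n + 1 slots
    figy = (slots + (tc + 1) * spacing) * barwidth
    margin = spacing * barwidth / figy      # top and bottom gap, figure fraction
    kwargs = {
        "height_ratios": [n + 1 for n in counts],
        # hspace is relative to the mean axes height, which is slots / tc slots
        "hspace": spacing * tc / slots,
        "top": 1 - margin,
        "bottom": margin,
        "left": left * barwidth / figx,
        "right": 1 - right * barwidth / figx,
    }
    return (figx, figy), kwargs

# Usage (requires matplotlib):
#   figsize, kw = gridspec_kwargs([len(c["entries"]) for c in data])
#   fig = plt.figure(figsize=figsize)
#   gs = fig.add_gridspec(len(data), 1, **kw)
#   ...then ax = fig.add_subplot(gs[index, 0], sharex=ax) for each category
```

Both approaches reserve n + 1 slots per subplot and spacing slots per gap, so the resulting axes should line up with the fig.add_axes positions computed above.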


Upvotes: 5
