Reputation: 25829
After struggling with matplotlib for longer than I'd like to admit, trying to do something that's a breeze in pretty much any other plotting library I've ever used, I've decided to ask the Stackiverse for some insight. In a nutshell, I need to create multiple horizontal bar charts that all share the x axis, each with a different number of values on the y axis, with every bar having the same height and each chart adjusting to its number of entries. A simplified version of the data structure I need to plot would be something like:
[
    {"name": "Category 1", "entries": [
        {"name": "Entry 1", "value": 5},
        {"name": "Entry 2", "value": 2},
    ]},
    {"name": "Category 2", "entries": [
        {"name": "Entry 1", "value": 1},
    ]},
    {"name": "Category 3", "entries": [
        {"name": "Entry 1", "value": 1},
        {"name": "Entry 2", "value": 10},
        {"name": "Entry 3", "value": 4},
    ]},
]
And the closest I got to what I'd like as a result is using:
import matplotlib.pyplot as plt

def plot_data(data):
    total_categories = len(data)  # holds how many charts to create
    max_values = 1  # holds the maximum number of bars to create
    for category in data:
        max_values = max(max_values, len(category["entries"]))
    fig = plt.figure(1)
    ax = None
    for index, category in enumerate(data):
        entries = []
        values = []
        for entry in category["entries"]:
            entries.append(entry["name"])
            values.append(entry["value"])
        if not entries:
            continue  # do not create empty charts
        y_ticks = range(1, len(entries) + 1)
        ax = fig.add_subplot(total_categories, 1, index + 1, sharex=ax)
        ax.barh(y_ticks, values)
        ax.set_ylim(0, max_values + 1)  # limit the y axis for fixed height
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(entries)
        ax.invert_yaxis()
        ax.set_title(category["name"], loc="left")
    fig.tight_layout()
This keeps the bar height the same (at least across the figure) no matter how many entries a category has, because the y limit (set_ylim()) is set to the highest number of bars across the data. However, it also leaves nasty gaps in categories that have fewer than the maximum number of entries. To put it in visual perspective, I'm trying to get from Actual to Expected:
I've tried removing the gaps through gridspec and different scales depending on the number of entries, but that only ended up looking even more 'skewed' and inconsistent. I tried generating multiple charts, manipulating the figure size and then stitching them together in post-processing, but I couldn't find a way to reliably keep the bar height the same no matter the data. I'm certain there is a way to extract the metrics needed for precise scaling from some obscure object in matplotlib, but at this point I'm afraid I'll go on another wild-goose chase if I try to trace the whole drawing procedure.
Furthermore, if individual subplots can be collapsed around the data, how could I make the figure grow based on the data? For example, if I were to add a fourth category to the above data, then instead of the figure 'growing' in height by another chart, all the charts shrink to fit everything into the default figure size. Now, I think I understand the logic behind matplotlib with axis units and all that, and I know I can set the figure size to increase the overall height, but I've no idea how to keep it consistent across the charts, namely how to keep the bar height exactly the same regardless of the data.
Do I really need to plot everything manually to get what I want? If so, I might just dump the whole matplotlib package and create my own SVGs from scratch. With hindsight, given the amount of time I've spent on this, I probably should've done that in the first place but now I'm way too stubborn to give it up (or I am a victim of the dreaded sunk cost fallacy).
Any ideas?
Thanks
Upvotes: 6
Views: 6234
Reputation: 339775
I think the only way to have equal bar widths (width in the vertical direction) and differing subplot sizes at the same time is to position the axes in the figure manually.
To this end you can specify how large in inches a bar should be and how much spacing you want to have between the subplots in units of this bar width. From those numbers together with the amount of data to plot, you can calculate the total figure height in inches.
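As a rough sketch of that height calculation on its own: the helper name figure_height is hypothetical, and the defaults simply mirror the barwidth and spacing values used in the full code further down. It assumes each subplot is given one more bar width than it has bars, with one gap of `spacing` bar widths above every subplot and one below the last.

```python
def figure_height(bars_per_category, barwidth=0.2, spacing=3):
    """Total figure height in inches (hypothetical helper).

    bars_per_category: number of bars in each subplot
    barwidth:          inches reserved per bar
    spacing:           gap between subplots, in units of barwidth
    """
    tc = len(bars_per_category)
    # every subplot is (n + 1) bar widths tall
    subplot_units = sum(n + 1 for n in bars_per_category)
    # tc + 1 gaps: one above each subplot plus one at the bottom
    gap_units = (tc + 1) * spacing
    return (subplot_units + gap_units) * barwidth


three = figure_height([2, 1, 3])
four = figure_height([2, 1, 3, 1])
print(three, four, four - three)
```

With these defaults, adding a fourth category with a single entry adds 2 bar widths for the subplot and 3 for the gap, so the figure grows by about one inch while the space reserved per bar stays at barwidth.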
Each of the subplots is then positioned (via fig.add_axes) according to the amount of data and the spacing in the previous subplots. That way you nicely fill up the plot.
Adding a new set of data will then make the figure larger.
data = [
    {"name": "Category 1", "entries": [
        {"name": "Entry 1", "value": 5},
        {"name": "Entry 2", "value": 2},
    ]},
    {"name": "Category 2", "entries": [
        {"name": "Entry 1", "value": 1},
    ]},
    {"name": "Category 3", "entries": [
        {"name": "Entry 1", "value": 1},
        {"name": "Entry 2", "value": 10},
        {"name": "Entry 3", "value": 4},
    ]},
    {"name": "Category 4", "entries": [
        {"name": "Entry 1", "value": 6},
    ]},
]
import matplotlib.pyplot as plt
import numpy as np

def plot_data(data,
              barwidth=0.2,  # inches per bar
              spacing=3,     # spacing between subplots in units of barwidth
              figx=5,        # figure width in inches
              left=4,        # left margin in units of barwidth
              right=2):      # right margin in units of barwidth
    tc = len(data)  # "total_categories", holds how many charts to create
    max_values = []  # holds the number of bars in each category
    for category in data:
        max_values.append(len(category["entries"]))
    max_values = np.array(max_values)
    # total figure height in inches:
    # (n + 1) bar widths per subplot plus tc + 1 gaps of `spacing` bar widths
    figy = ((np.sum(max_values) + tc) + (tc + 1) * spacing) * barwidth
    fig = plt.figure(figsize=(figx, figy))
    ax = None
    for index, category in enumerate(data):
        entries = []
        values = []
        for entry in category["entries"]:
            entries.append(entry["name"])
            values.append(entry["value"])
        if not entries:
            continue  # do not create empty charts
        y_ticks = range(1, len(entries) + 1)
        # coordinates of new axes [left, bottom, width, height],
        # all expressed as fractions of the figure size
        coord = [left * barwidth / figx,
                 1 - barwidth * ((index + 1) * spacing + np.sum(max_values[:index + 1]) + index + 1) / figy,
                 1 - (left + right) * barwidth / figx,
                 (max_values[index] + 1) * barwidth / figy]
        ax = fig.add_axes(coord, sharex=ax)
        ax.barh(y_ticks, values)
        ax.set_ylim(0, max_values[index] + 1)  # limit the y axis for fixed height
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(entries)
        ax.invert_yaxis()
        ax.set_title(category["name"], loc="left")

plot_data(data)
plt.savefig(__file__ + ".png")
plt.show()
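To see why the bars end up equally tall, the positioning arithmetic can be checked without drawing anything. The following is only a sketch: axes_rects is a hypothetical helper that reproduces the coord calculation from plot_data in plain Python and reports how many inches one y data unit occupies in each subplot.

```python
def axes_rects(bars_per_category, barwidth=0.2, spacing=3,
               figx=5, left=4, right=2):
    """Return (figy, rects); rects are [left, bottom, width, height]
    in figure fractions, the form fig.add_axes expects."""
    tc = len(bars_per_category)
    figy = (sum(bars_per_category) + tc + (tc + 1) * spacing) * barwidth
    rects = []
    used = 0  # bar-width units consumed so far, measured from the top
    for n in bars_per_category:
        used += spacing + n + 1  # gap above, then the subplot itself
        rects.append([left * barwidth / figx,
                      1 - used * barwidth / figy,
                      1 - (left + right) * barwidth / figx,
                      (n + 1) * barwidth / figy])
    return figy, rects


counts = [2, 1, 3, 1]
figy, rects = axes_rects(counts)
for n, (l, b, w, h) in zip(counts, rects):
    # axes height in inches divided by the y range set via set_ylim(0, n + 1)
    inches_per_unit = h * figy / (n + 1)
    print(n, round(b, 3), round(inches_per_unit, 3))
```

Each axes is (n + 1) * barwidth inches tall and its y limits span n + 1 data units, so one data unit always maps to barwidth inches. Since barh draws bars 0.8 data units thick by default, every bar comes out at 0.8 * barwidth inches regardless of how many entries its category has.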
Upvotes: 5