TheStrangeQuark

Reputation: 2405

Showing legend under matplotlib plot with varying number of plots

I'm working on a program that lets users put a varying number of plots onto one axis. I need to show the legend, and because the plot labels are long, I think it is best to place it under the plot. When I've done this before, I've always shrunk the plot a bit and put the legend underneath. With a varying number of plots this isn't as simple, because I can't find a good formula for how much to shrink the plot and how far down to put the legend so that it is neither cut off nor overlapping the axis.

I've written example code to demonstrate what I am currently doing, which is ugly. I check how many items are in the plot; for each count I manually tuned the axis-shrink and legend-offset parameters, and a long if/elif chain picks the matching values. The values are not tuned for this example code, but I think it demonstrates the approach.

import matplotlib.pyplot as plt
import numpy as np

def find_scales(legendData):
    leg_len = len(legendData)
    if leg_len == 0:
        height_scale = 1
        legend_offset = 0
    elif leg_len == 1:
        height_scale = 0.96
        legend_offset = -0.18
    elif leg_len == 2:
        height_scale = 0.95
        legend_offset = -0.25
    elif leg_len == 3:
        height_scale = 0.94
        legend_offset = -0.35
    elif leg_len == 4:
        height_scale = 0.93
        legend_offset = -0.45
    elif leg_len == 5:
        height_scale = 0.93
        legend_offset = -0.57
    elif leg_len == 6:
        height_scale = 0.93
        legend_offset = -0.68
    elif leg_len == 7:
        height_scale = 0.93
        legend_offset = -0.82
    elif leg_len == 8:
        height_scale = 0.93
        legend_offset = -0.98
    elif leg_len == 9:
        height_scale = 0.92
        legend_offset = -1.3
    elif leg_len == 10:
        height_scale = 0.92
        legend_offset = -1.53
    else:
        height_scale = 0.92
        legend_offset = -1.8
    return height_scale, legend_offset

num_plots = 3
x_range = np.arange(10)

fig,ax = plt.subplots()

for i in range(num_plots):
    ax.plot(x_range, np.random.rand(10))

legend_labels = ['a','b','c','d','e','f','g','h','i','j'][:num_plots]

box = ax.get_position()

height_scale, legend_offset = find_scales(legend_labels)

ax.set_position([box.x0, box.y0 + box.height * (1 - height_scale),  # left, bottom, width, height
                 box.width, box.height * height_scale])

ax.legend(legend_labels, loc=3, bbox_to_anchor=(0,legend_offset), borderaxespad=0.)

plt.show()
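For reference, the if/elif chain in find_scales can be collapsed into a lookup table holding the same hand-tuned numbers. This is only a tidier form of the same approach: the values still have to be tuned by hand for every entry count, which is exactly the problem being asked about. The names `_SCALES` and `_DEFAULT_SCALE` are illustrative.

```python
# Hand-tuned (height_scale, legend_offset) pairs copied from the if/elif chain
# above, keyed by the number of legend entries.
_SCALES = {
    0: (1, 0),
    1: (0.96, -0.18),
    2: (0.95, -0.25),
    3: (0.94, -0.35),
    4: (0.93, -0.45),
    5: (0.93, -0.57),
    6: (0.93, -0.68),
    7: (0.93, -0.82),
    8: (0.93, -0.98),
    9: (0.92, -1.3),
    10: (0.92, -1.53),
}
# Fallback for more than 10 entries (the original else branch).
_DEFAULT_SCALE = (0.92, -1.8)


def find_scales(legendData):
    """Return (height_scale, legend_offset) for the given legend labels."""
    return _SCALES.get(len(legendData), _DEFAULT_SCALE)


print(find_scales(["a", "b", "c"]))  # → (0.94, -0.35)
```

It behaves the same as the original function, so it can be dropped in without changing the rest of the script.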

I'm hoping there's a better way to do this. I want the legend under the axis. The legend must not overlap the axis or the x-label, and it must not sit so low that it is cut off at the bottom of the figure. Is there a way to make the axis and legend size themselves automatically to fit in the figure?

Upvotes: 0

Views: 1003

Answers (1)

ImportanceOfBeingErnest

Reputation: 339102

A way to correct the axes position so that the legend has enough space is shown in the answer to this question: "Creating figure with exact size and no padding (and legend outside the axes)".

For a legend at the bottom, the solution is much simpler. You only need to subtract the legend height from the axes height and move the axes up by the legend height.
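The arithmetic in that step can be written as a small pure function. This is a sketch with a hypothetical helper name, `shift_for_legend`; all values are figure fractions (0 to 1), the same units that `ax.get_position()` uses. The `ValueError` guard is an extra check that is not part of the answer.

```python
def shift_for_legend(position, legend_height):
    """Return a new [left, bottom, width, height] that frees legend_height
    below the axes while keeping the top edge where it was.

    Hypothetical helper; all values are figure fractions (0 to 1).
    """
    left, bottom, width, height = position
    if legend_height >= height:
        raise ValueError("legend is taller than the axes")
    # Move the bottom edge up by the legend height...
    new_bottom = bottom + legend_height
    # ...and shrink the height by the same amount, so the top edge stays put.
    new_height = height - legend_height
    return [left, new_bottom, width, new_height]


print(shift_for_legend([0.125, 0.125, 0.75, 0.75], 0.25))  # → [0.125, 0.375, 0.75, 0.5]
```

The top edge is unchanged (0.125 + 0.75 = 0.375 + 0.5 = 0.875); only the bottom moves up to make room for the legend.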

import matplotlib.pyplot as plt 

fig = plt.figure(figsize = [3.5,2]) 
ax = fig.add_subplot(111)
ax.set_title('title')
ax.set_ylabel('y label')
ax.set_xlabel('x label')
ax.plot([1,2,3], marker="o", label="quantity 1")
ax.plot([2,1.7,1.2], marker="s", label="quantity 2")

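def legend(ax, x0=0.5, y0=0, pad=0.5, **kwargs):
    otrans = ax.figure.transFigure
    # Anchor the legend's lower-center point (loc=8) at (x0, y0) in figure coordinates.
    t = ax.legend(bbox_to_anchor=(x0, y0), loc=8, bbox_transform=otrans, **kwargs)
    ax.figure.tight_layout(pad=pad)
    # Draw once so that the legend's size is known.
    ax.figure.canvas.draw()
    # Legend extent converted from display (pixel) coordinates to figure coordinates.
    tbox = t.get_window_extent().transformed(ax.figure.transFigure.inverted())
    bbox = ax.get_position()
    # Raise the bottom of the axes by the legend height and shrink its height by the same amount.
    ax.set_position([bbox.x0, bbox.y0 + tbox.height, bbox.width, bbox.height - tbox.height])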

legend(ax,y0=0, borderaxespad=0.2)

plt.savefig(__file__+'.pdf')
plt.show()
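On newer Matplotlib versions there is also a built-in alternative that is not part of the answer above: with constrained layout, a figure legend placed with an `"outside ..."` location string gets its own space, and the axes shrink to make room for it. The sketch below assumes Matplotlib 3.7 or later (where `loc="outside lower center"` is available) and uses the non-interactive Agg backend so it can run headless; `plot_with_bottom_legend` is a hypothetical helper name.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np


def plot_with_bottom_legend(num_plots, ncols=1, seed=0):
    """Plot num_plots random lines with the legend below the axes.

    Assumes Matplotlib >= 3.7 for the "outside" legend locations.
    """
    rng = np.random.default_rng(seed)
    x = np.arange(10)
    # Constrained layout recomputes the axes size at draw time,
    # so it can reserve space for the legend automatically.
    fig, ax = plt.subplots(layout="constrained")
    for i in range(num_plots):
        ax.plot(x, rng.random(10), label=f"a fairly long label for series {i}")
    ax.set_xlabel("x label")
    ax.set_ylabel("y label")
    # Figure-level legend, placed outside the axes area at the bottom.
    leg = fig.legend(loc="outside lower center", ncols=ncols)
    return fig, ax, leg


fig, ax, leg = plot_with_bottom_legend(4)
```

After this, `fig.savefig(...)` or `plt.show()` works as usual. For many entries, passing `ncols=2` or more keeps the legend from taking most of the figure height.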


Upvotes: 1
