RGWinston

Reputation: 415

Adding a legend to a matplotlib plot with a multicolored line

Following the example on how to draw multicolored lines, I can draw lines that change color along their length according to a colormap. To add a legend to the plot, I added this code:

plt.legend([lc], ["test"],
           handler_map={lc: matplotlib.legend_handler.HandlerLineCollection()})

This adds a legend to the plot (figure below), but the color of the icon in the legend bears no relation to the colors of the line. Is this the wrong way to add a legend to this plot, or is this a limitation of matplotlib?

[Figure: attempt at a multicolored line with a legend]

Upvotes: 2

Views: 2738

Answers (1)

ImportanceOfBeingErnest

Reputation: 339660

The idea is to show a line collection in the legend as well. There is no built-in way to do that, but one can subclass HandlerLineCollection and create the corresponding LineCollection within its create_artists method.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerLineCollection
from matplotlib.collections import LineCollection

class HandlerColorLineCollection(HandlerLineCollection):
    def create_artists(self, legend, artist, xdescent, ydescent,
                       width, height, fontsize, trans):
        # Build a short horizontal line across the legend handle box,
        # split into numpoints segments.
        n = self.get_numpoints(legend)
        x = np.linspace(0, width, n + 1)
        y = np.zeros(n + 1) + height / 2. - ydescent
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        # Color the segments with the colormap of the plotted collection.
        lc = LineCollection(segments, cmap=artist.cmap,
                            transform=trans)
        lc.set_array(x)
        lc.set_linewidth(artist.get_linewidth())
        return [lc]

t = np.linspace(0, 10, 200)
x = np.cos(np.pi * t)
y = np.sin(t)
points = np.array([x, y]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)

lc = LineCollection(segments, cmap=plt.get_cmap('copper'),
                    norm=plt.Normalize(0, 10), linewidth=3)
lc.set_array(t)

fig, ax = plt.subplots()
ax.add_collection(lc)

plt.legend([lc], ["test"],
           handler_map={lc: HandlerColorLineCollection(numpoints=4)},
           framealpha=1)

ax.autoscale_view()
plt.show()

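If several multicolored lines share one legend, the same handler can be registered once for the whole LineCollection class instead of once per instance, because handler_map also accepts types as keys. The following is only a sketch under some assumptions: the colored_line helper, the second line, and the 'viridis' colormap are made up for illustration and are not part of the question, and the Agg backend is selected only so the script runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption); remove for interactive use
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.legend_handler import HandlerLineCollection


class HandlerColorLineCollection(HandlerLineCollection):
    """Draw the legend icon as a short line colored with the artist's colormap."""

    def create_artists(self, legend, artist, xdescent, ydescent,
                       width, height, fontsize, trans):
        n = self.get_numpoints(legend)
        x = np.linspace(0, width, n + 1)
        y = np.full(n + 1, height / 2.0 - ydescent)
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        icon = LineCollection(segments, cmap=artist.get_cmap(), transform=trans)
        icon.set_array(x[:-1])  # one scalar per segment
        icon.set_linewidth(artist.get_linewidth())
        return [icon]


def colored_line(t, x, y, cmap_name):
    """Hypothetical helper: build a LineCollection colored by the values of t."""
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    lc = LineCollection(segments, cmap=cmap_name,
                        norm=plt.Normalize(t.min(), t.max()), linewidth=3)
    lc.set_array(t[:-1])  # one scalar per segment
    return lc


t = np.linspace(0, 10, 200)
lc1 = colored_line(t, np.cos(np.pi * t), np.sin(t), "copper")
lc2 = colored_line(t, np.cos(np.pi * t), np.sin(t) + 2.5, "viridis")

fig, ax = plt.subplots()
ax.add_collection(lc1)
ax.add_collection(lc2)
ax.autoscale_view()

# Keying the map by the class applies the handler to every LineCollection handle.
legend = ax.legend([lc1, lc2], ["copper", "viridis"],
                   handler_map={LineCollection: HandlerColorLineCollection(numpoints=4)},
                   framealpha=1)

fig.canvas.draw()  # render once; call plt.show() instead when working interactively
```

An instance key, as in the code above, still takes precedence over a class key, so a single line can be given a different handler if needed.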

Upvotes: 10
