Daniel

Reputation: 5381

Legend in matplotlib shows first entry of a list only

I'm trying to display a custom legend for a bar graph, but only the first entry of the legend list is shown. How can I display all the entries in the legend?

df.time_to_travel_grouping.value_counts().plot(kind="bar", 
                                               color = ["b","tab:green","tab:red","c","m","y","tab:blue","tab:orange"],
                                               xlabel="TTT", ylabel="Total Counts", 
                                               title="Fig4: Total Counts by Time to Travel Category (TTT)", figsize=(20,15))
plt.legend(["a","b","c","d","e","f","g","h"])
plt.subplots_adjust(bottom=0.15)
plt.subplots_adjust(left=0.15)


Upvotes: 1

Views: 2130

Answers (3)

Celuk

Reputation: 897

Just passing a list of strings to the legend function does not work the way you might expect in matplotlib. To add all the desired entries, you can build patch objects from the labels and their colors and pass those to the legend as handles. This piece of code will do the job, and I think it is more general than the other solutions:

## include this library
import matplotlib.patches as mpatches

## desired legends
legend_list = ["a","b","c","d","e","f","g","h"]
## corresponding colors in the same order
color_list = ["b","tab:green","tab:red","c","m","y","tab:blue","tab:orange"]

## make patches from the legends and corresponding colors
patch_list = [mpatches.Patch(label=each_legend, color=each_color)
              for each_legend, each_color in zip(legend_list, color_list)]

## add made patches to the plot
plt.legend(handles=patch_list, fontsize=12, loc=(1, 0))
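For reference, here is a minimal end-to-end sketch of the same patch-based idea. The tiny DataFrame, its values and the four-color list are made up for illustration and are not the asker's data. The labels are taken from the value_counts() index so that they follow the bar order instead of being typed by hand:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pandas as pd

# hypothetical data standing in for the asker's column
df = pd.DataFrame({"time_to_travel_grouping": list("aaaabbbccd")})
counts = df["time_to_travel_grouping"].value_counts()
color_list = ["b", "tab:green", "tab:red", "c"]

ax = counts.plot(kind="bar", color=color_list)

# one patch per bar, labelled from the index so the order matches the bars
patch_list = [mpatches.Patch(label=str(name), color=color)
              for name, color in zip(counts.index, color_list)]
legend = ax.legend(handles=patch_list, fontsize=12, loc=(1, 0))

# collect the legend texts to confirm there is one entry per bar
legend_texts = [text.get_text() for text in legend.get_texts()]
print(legend_texts)
```

Because the patches are independent of the plotted bars, the colors in the legend only match the plot as long as color_list is the same list that was passed to plot().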

Upvotes: 0

JohanC

Reputation: 80329

To create an automatic legend, matplotlib stores labels for graphical elements. In the case of this bar plot, pandas assigns one label to the complete 'container' that holds all the bars, so the legend only has a single entry to show.
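A quick way to check this is to inspect what the axes would hand to an automatic legend. The sketch below uses a small made-up Series; the name "count" and the values are assumptions for illustration only:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# hypothetical stand-in for df.time_to_travel_grouping.value_counts()
counts = pd.Series([9, 7, 4, 2], index=list("abcd"), name="count")
ax = counts.plot(kind="bar")

handles, labels = ax.get_legend_handles_labels()
n_bars = len(ax.containers[0])  # individual bars that were drawn
n_entries = len(handles)        # entries an automatic legend would receive
print(n_bars, n_entries, labels)
```

This should report four bars but only one legend entry, labelled with the Series name, which is consistent with the single entry seen in the question.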

You could remove the label of the container (assigning a label starting with _), and assign individual labels to the bars. The xtick labels can be used, as they are already in the desired order.

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

df = pd.DataFrame({'time_to_travel_grouping': np.random.choice([*'abcdefgh'], 200)})
ax = df.time_to_travel_grouping.value_counts().plot(kind="bar",
                                                    color=["b", "tab:green", "tab:red", "c", "m", "y", "tab:blue", "tab:orange"],
                                                    xlabel="TTT", ylabel="Total Counts",
                                                    title="Fig4: Total Counts by Time to Travel Category (TTT)",
                                                    figsize=(20, 15))

ax.containers[0].set_label('_nolegend')
for bar, tick_label in zip(ax.containers[0], ax.get_xticklabels()):
    bar.set_label(tick_label.get_text())
ax.legend()
plt.tight_layout()
plt.show()

pandas bar plot with individual bars in legend
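The underscore convention used for the container label can also be checked in isolation. This is a small sketch with two arbitrary lines, unrelated to the asker's data:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], label="shown")
ax.plot([0, 1], [1, 0], label="_hidden")  # leading underscore: skipped by the automatic legend

handles, labels = ax.get_legend_handles_labels()
print(labels)
```

Only the line labelled "shown" is collected, which is why relabelling the container with '_nolegend' keeps it out of the legend while the individually labelled bars appear.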

With a little bit less internal manipulation, something similar can be obtained via seaborn:


import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

df = pd.DataFrame({'time_to_travel_grouping': np.random.choice([*'abcdefgh'], 200)})

plt.figure(figsize=(20, 15))
ax = sns.countplot(data=df, x='time_to_travel_grouping', hue='time_to_travel_grouping',
                   palette=["b", "tab:green", "tab:red", "c", "m", "y", "tab:blue", "tab:orange"],
                   order=df.time_to_travel_grouping.value_counts().index,
                   dodge=False)
plt.setp(ax, xlabel="TTT", ylabel="Total Counts", title="Fig4: Total Counts by Time to Travel Category (TTT)")
plt.tight_layout()
plt.show()

sns.countplot

Upvotes: 1

Scott Boston

Reputation: 153460

Let's get the legend handles from the axes using ax.get_legend_handles_labels and pair them with one label per bar:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# naming the Series so that the bar container gets a non-empty legend label
s = pd.Series(np.arange(100,50,-5), index=[*'abcdefghij'], name="counts")
ax = s.plot(kind="bar", 
           color = ["b","tab:green","tab:red","c","m","y","tab:blue","tab:orange"],
           xlabel="TTT", ylabel="Total Counts", 
           title="Fig4: Total Counts by Time to Travel Category (TTT)", figsize=(20,15))

# the returned handle is the bar container, which holds the individual bars
patches, _ = ax.get_legend_handles_labels()
labels = [*'abcdefghij']
ax.legend(*patches, labels, loc='best')

plt.subplots_adjust(bottom=0.15, left=0.15)


Upvotes: 2
