math

Reputation: 2022

How to create correct legend entries in seaborn for line plots of different styles and a scatter plot

I want to combine three plots with different line styles and have the legend reflect them correctly. I have two line charts, each consisting of two lines: one chart uses solid lines, the other dashed lines. On top of these I'm adding a scatter plot.

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
import numpy as np
#sample data generation
plot_df = pd.DataFrame(index=np.arange(5), columns=["Series 1", "Series 2"], data=np.array([[1, 2],[2.4, 5],[4.1, 7.1],[5, 8.9],[5.2, 10]]))
plot_df_dash = pd.DataFrame(index=np.arange(5), columns=["Series 1 dashed", "Series 2 dashed"], data=np.array([[2, 3],[3.4, 4],[5.1, 6.1],[7, 1.9],[4.2, 12]]))
plot_df_points = pd.DataFrame(index = [1.5, 2, 3.7], columns = ["Series 1", "Series 2"], data=np.array([[1.2, 3.4],[4.5, 6.9],[5.5, 9.6]]))
df = pd.DataFrame(plot_df.stack()).reset_index()
df_dash = pd.DataFrame(plot_df.stack()).reset_index()
df.columns = ["x", "Series","y"]
df_dash.columns=["x", "Series dashed","y"]
df_points = pd.DataFrame(plot_df_points.stack()).reset_index()
df_points.columns = ["x", "Series","y"]

#plotting
fig, ax = plt.subplots()
sns.lineplot(data=df,x="x",y="y", hue="Series",ax=ax,palette="rocket",linewidth=2.5)
sns.lineplot(data=df_dash,x="x",y="y", hue="Series dashed",ax=ax,palette="rocket",linewidth=2.5,linestyle="--")
sns.scatterplot(data=df_points, x="x", y="y", hue="Series", ax=ax,s=200)

#generating legend
handles, labels = ax.get_legend_handles_labels()
ax.legend([tuple(handles[::2]), tuple(handles[1::2])], labels[:4], handlelength=3,
          handler_map={tuple: HandlerTuple(ndivide=None)})

plt.show()
plt.close()

I'm struggling to get the legend right. The scatter plot markers should be shared between the dashed and solid line plots. That is, the legend should show four entries: two solid lines with the circles from the scatter plot, plus two dashed lines with circles from the same scatter plot.

Upvotes: 2

Views: 3126

Answers (1)

Mr. T

Reputation: 12410

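As I said, one way is to manually set the correct line properties on the legend handles, because the handles seaborn generates for the dashed lines do not carry the linestyle and linewidth passed to lineplot.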

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
import numpy as np
#sample data generation
plot_df = pd.DataFrame(index=np.arange(5), columns=["Series 1", "Series 2"], data=np.array([[1, 2],[2.4, 5],[4.1, 7.1],[5, 8.9],[5.2, 10]]))
plot_df_dash = pd.DataFrame(index=np.arange(5), columns=["Series 1 dashed", "Series 2 dashed"], data=np.array([[2, 3],[3.4, 4],[5.1, 6.1],[7, 1.9],[4.2, 12]]))
plot_df_points = pd.DataFrame(index = [1.5, 2, 3.7], columns = ["Series 1", "Series 2"], data=np.array([[1.2, 3.4],[4.5, 6.9],[5.5, 9.6]]))
df = pd.DataFrame(plot_df.stack()).reset_index()
#changed the dataframe generation here - the reason why you did not see dashed lines
df_dash = pd.DataFrame(plot_df_dash.stack()).reset_index()
df.columns = ["x", "Series","y"]
df_dash.columns=["x", "Series dashed","y"]
df_points = pd.DataFrame(plot_df_points.stack()).reset_index()
df_points.columns = ["x", "Series","y"]

#plotting
fig, ax = plt.subplots()
#same color palette for all series; set_palette makes it the default, whereas color_palette only returns the colors
sns.set_palette("rocket")
#defining linestyle and width for dashed line
ls = "--"
lw = 3.5
sns.lineplot(data=df, x="x", y="y", hue="Series", ax=ax, linewidth=2.5)
sns.lineplot(data=df_dash, x="x", y="y", hue="Series dashed", ax=ax, linewidth=lw, linestyle=ls)
sns.scatterplot(data=df_points, x="x", y="y", hue="Series", ax=ax, s=200)

#generating legend
handles, labels = ax.get_legend_handles_labels()
#seaborn's legend handles for the dashed lines (indices 2 and 3) do not carry the custom style, so set it manually
for i in [2, 3]:
    handles[i].set_linestyle(ls)
    handles[i].set_linewidth(lw)
#pair each line handle with the scatter handle of the same series (indices 4 and 5)
ax.legend([(handles[0], handles[4]),
           (handles[1], handles[5]),
           (handles[2], handles[4]),
           (handles[3], handles[5])],
          labels[:4], handlelength=7,
          handler_map={tuple: HandlerTuple(ndivide=None)})

plt.show()

Sample output: (image not included)

As an aside, it turned out you didn't see the dashed lines because df_dash was built from plot_df instead of plot_df_dash, so the dashed lines were drawn exactly on top of the solid ones.
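A quick way to confirm this kind of mix-up is to compare the reshaped frames directly. The following is only a sketch: the helper `to_long` and the names `df_dash_wrong` and `df_dash_fixed` are made up for illustration and do not appear in the question.

```python
import numpy as np
import pandas as pd

# Same sample data as in the question
plot_df = pd.DataFrame(
    index=np.arange(5), columns=["Series 1", "Series 2"],
    data=np.array([[1, 2], [2.4, 5], [4.1, 7.1], [5, 8.9], [5.2, 10]]))
plot_df_dash = pd.DataFrame(
    index=np.arange(5), columns=["Series 1 dashed", "Series 2 dashed"],
    data=np.array([[2, 3], [3.4, 4], [5.1, 6.1], [7, 1.9], [4.2, 12]]))


def to_long(wide, hue_name):
    """Hypothetical helper: reshape a wide frame into x / hue / y columns."""
    long = wide.stack().reset_index()
    long.columns = ["x", hue_name, "y"]
    return long


df = to_long(plot_df, "Series")
df_dash_wrong = to_long(plot_df, "Series dashed")       # the question's version
df_dash_fixed = to_long(plot_df_dash, "Series dashed")  # the corrected version

# If the y values are identical, the dashed lines are hidden behind the solid ones
same_as_solid_wrong = df_dash_wrong["y"].equals(df["y"])
same_as_solid_fixed = df_dash_fixed["y"].equals(df["y"])
print(same_as_solid_wrong, same_as_solid_fixed)
```

If the first comparison comes out true, both lineplot calls are drawing the same data.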

If you only wanted one legend entry per series (line styles are often explained in the figure caption rather than the legend), the code would simplify to:

#generating legend
handles, labels = ax.get_legend_handles_labels()
for i in [2, 3]:
    handles[i].set_linestyle(ls)
    handles[i].set_linewidth(lw)

#handles[::2] groups the solid line, dashed line and marker of Series 1; handles[1::2] does the same for Series 2
ax.legend([tuple(handles[::2]), tuple(handles[1::2])], labels[:2], handlelength=10,
          handler_map={tuple: HandlerTuple(ndivide=None)})

plt.show()

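If relying on the order and style of the handles seaborn returns feels fragile, another option is to build the legend entries by hand from matplotlib proxy artists. This is not what the code above does; it is a minimal sketch of a different technique, and the colors and the `colors`/`styles` dictionaries are placeholders chosen for illustration. In a real plot you would reuse the colors seaborn actually assigned.

```python
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
from matplotlib.lines import Line2D

fig, ax = plt.subplots()

# Placeholder colors and line styles (assumptions, not taken from the plot above)
colors = {"Series 1": "tab:blue", "Series 2": "tab:orange"}
styles = {"": "-", " dashed": "--"}

entries, names = [], []
for suffix, ls in styles.items():
    for series, color in colors.items():
        # Proxy artists: empty data, used only to draw the legend symbols
        line = Line2D([], [], color=color, linestyle=ls, linewidth=2.5)
        marker = Line2D([], [], color=color, linestyle="none",
                        marker="o", markersize=10)
        entries.append((line, marker))
        names.append(series + suffix)

# One tuple per legend entry; HandlerTuple draws its members side by side
legend = ax.legend(entries, names, handlelength=7,
                   handler_map={tuple: HandlerTuple(ndivide=None)})
fig.canvas.draw()
```

Because the proxies are created explicitly, the legend no longer depends on how a particular seaborn version builds its handles. Call `plt.show()` or `fig.savefig(...)` afterwards to view the result.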

Upvotes: 2
