Prakhar Sharma

Reputation: 758

matplotlib legend not showing correctly

I am trying to plot some data from a CSV file, which I load with pandas, and I use sns.lineplot() to draw the lines. However, one legend entry is always faulty: it shows a square instead of a line.

import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(dpi=150)
lin1 = sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"] == "tanh"])
lin2 = sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"] == "silu"])
lin3 = sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"] == "swish"])
plt.xlabel("Training time  (sec)")
plt.legend(("tanh", "silu", "swish"))
plt.yscale("log", base=10)

I call sns.lineplot() three separate times because the data contains more activations than these three. This is the resulting plot.

The lines look correct, but the legend does not. These are the versions of the plotting tools I am using:

Python 3.9.12
matplotlib                    3.6.1
matplotlib-inline             0.1.6
seaborn                       0.12.1

I could not find this issue anywhere online, and restarting the kernel does not help. Please let me know if more information is needed.

Upvotes: 0

Views: 276

Answers (2)

Tranbi

Reputation: 12701

You can also plot all your lines with a single command by using hue:

sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"].isin(["tanh", "silu", "swish"])], hue="Activation")

Edit: as @JohanC cleverly suggested, you could use hue_order to get a slightly more compact expression:

sns.lineplot(x="Training time", y="Relative L2 error", data=df, hue="Activation", hue_order=["tanh", "silu", "swish"])

Upvotes: 3

Michael Cao

Reputation: 3609

Try passing the label argument to each sns.lineplot() call and then calling plt.legend() without any arguments:

import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(dpi=150)
sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"] == "tanh"], label="tanh")
sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"] == "silu"], label="silu")
sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"] == "swish"], label="swish")
plt.xlabel("Training time  (sec)")
plt.legend()  # no arguments: the legend uses the labels set above
plt.yscale("log", base=10)

Upvotes: 1
