erotavlas

Reputation: 4483

How do I combine these two line plots together using seaborn?

I want to combine these two plots into one. What is the recommended approach? Is there a way to do it using a single dataframe?

sns.relplot(x="epoch", y="loss", kind="line", color='orange', ci="sd", data=losses_df);
sns.relplot(x="epoch", y="loss", kind="line", color='red', ci="sd", data=val_losses_df);

The data for the first plot is shown below, with the columns epoch and loss in that order:

losses_df

epoch loss
0.0 0.156077
0.0 0.013558
0.0 0.007013
1.0 0.029891
1.0 0.008320
1.0 0.003487
2.0 0.017474
2.0 0.006232
2.0 0.002457
3.0 0.013332
3.0 0.004897
3.0 0.001900
4.0 0.010947
4.0 0.003905
4.0 0.001594
5.0 0.009127
5.0 0.003195
5.0 0.001341
6.0 0.007751
6.0 0.002681
6.0 0.001157
7.0 0.006605
7.0 0.002218
7.0 0.000972
8.0 0.005630
8.0 0.001867
8.0 0.000832
9.0 0.004839
9.0 0.001671
9.0 0.000748

val_losses_df

epoch loss
0.0 0.048945
0.0 0.006090
0.0 0.002332
1.0 0.024670
1.0 0.006243
1.0 0.002337
2.0 0.022344
2.0 0.006609
2.0 0.002626
3.0 0.022037
3.0 0.007156
3.0 0.003080
4.0 0.022025
4.0 0.008209
4.0 0.003835
5.0 0.022751
5.0 0.009226
5.0 0.004209
6.0 0.024093
6.0 0.009950
6.0 0.004783
7.0 0.025410
7.0 0.011130
7.0 0.005279
8.0 0.028299
8.0 0.012204
8.0 0.005969
9.0 0.028623
9.0 0.013037
9.0 0.006519
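To reproduce the answers below, the printed tables can be loaded into DataFrames. A minimal sketch, assuming the rows are pasted as whitespace-separated text without a header row (only epochs 0 and 1 are included here to keep it short); `load_losses` is a hypothetical helper name, not part of pandas:

```python
import io

import pandas as pd

# Excerpt of the tables above (epochs 0 and 1 only); paste the full
# tables here to reproduce the plots exactly.
LOSSES_TEXT = """\
0.0 0.156077
0.0 0.013558
0.0 0.007013
1.0 0.029891
1.0 0.008320
1.0 0.003487
"""

VAL_LOSSES_TEXT = """\
0.0 0.048945
0.0 0.006090
0.0 0.002332
1.0 0.024670
1.0 0.006243
1.0 0.002337
"""


def load_losses(text: str) -> pd.DataFrame:
    """Parse whitespace-separated 'epoch loss' rows (hypothetical helper)."""
    # header=None because the pasted rows have no header; names= supplies one
    return pd.read_csv(io.StringIO(text), sep=r"\s+", header=None, names=["epoch", "loss"])


losses_df = load_losses(LOSSES_TEXT)
val_losses_df = load_losses(VAL_LOSSES_TEXT)
```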

Here are my plots so far; I want them combined into one plot with a legend:


Upvotes: 0

Views: 2119

Answers (2)

JohanC

Reputation: 80329

You could combine the two dataframes into one, with a column recording which dataframe each row came from, and map that column to hue:

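import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# losses_df = pd.read_csv(...)
# val_losses_df = pd.read_csv(...)

# Tag each row with its source so seaborn can map it to hue (and the legend)
losses_df['type'] = 'losses'
val_losses_df['type'] = 'val_losses'

# Stack the two frames; ignore_index avoids duplicate row labels
combined_df = pd.concat([losses_df, val_losses_df], ignore_index=True)

# On seaborn >= 0.12, use errorbar="sd" instead of the deprecated ci="sd"
sns.relplot(x="epoch", y="loss", kind="line", hue="type", palette=['orange', 'red'], ci="sd", data=combined_df)

plt.tight_layout()
plt.show()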

[Figure: example plot]

Upvotes: 2

Diziet Asahi

Reputation: 40697

relplot() is a figure-level function that creates a new figure at each call. Use the axes-level lineplot() instead and draw both lines on the same Axes:

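import matplotlib.pyplot as plt
import seaborn as sns
fig, ax = plt.subplots()  # one Axes shared by both lines
sns.lineplot(x="epoch", y="loss", color='orange', ci="sd", label='losses', data=losses_df, ax=ax)
sns.lineplot(x="epoch", y="loss", color='red', ci="sd", label='val_losses', data=val_losses_df, ax=ax)
ax.legend()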

Upvotes: 1
