Reputation: 4483
I want to combine these two plots into one. What is the recommended approach? Is there a way to do it using a single DataFrame?
import seaborn as sns
sns.relplot(x="epoch", y="loss", kind="line", color='orange', ci="sd", data=losses_df);
sns.relplot(x="epoch", y="loss", kind="line", color='red', ci="sd", data=val_losses_df);
The data for the first plot is the following, with columns epoch and loss in that order:
losses_df
epoch loss
0.0 0.156077
0.0 0.013558
0.0 0.007013
1.0 0.029891
1.0 0.008320
1.0 0.003487
2.0 0.017474
2.0 0.006232
2.0 0.002457
3.0 0.013332
3.0 0.004897
3.0 0.001900
4.0 0.010947
4.0 0.003905
4.0 0.001594
5.0 0.009127
5.0 0.003195
5.0 0.001341
6.0 0.007751
6.0 0.002681
6.0 0.001157
7.0 0.006605
7.0 0.002218
7.0 0.000972
8.0 0.005630
8.0 0.001867
8.0 0.000832
9.0 0.004839
9.0 0.001671
9.0 0.000748
val_losses_df
epoch loss
0.0 0.048945
0.0 0.006090
0.0 0.002332
1.0 0.024670
1.0 0.006243
1.0 0.002337
2.0 0.022344
2.0 0.006609
2.0 0.002626
3.0 0.022037
3.0 0.007156
3.0 0.003080
4.0 0.022025
4.0 0.008209
4.0 0.003835
5.0 0.022751
5.0 0.009226
5.0 0.004209
6.0 0.024093
6.0 0.009950
6.0 0.004783
7.0 0.025410
7.0 0.011130
7.0 0.005279
8.0 0.028299
8.0 0.012204
8.0 0.005969
9.0 0.028623
9.0 0.013037
9.0 0.006519
My plots so far are two separate figures; I want them combined into one plot with a legend.
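If you want to reproduce this, the two tables above can be rebuilt as DataFrames. A minimal sketch, assuming the values are exactly as posted with three rows per epoch in the listed order; make_df, LOSSES and VAL_LOSSES are hypothetical names:

```python
import numpy as np
import pandas as pd

# Loss values copied from the two tables above, in row order
# (three rows per epoch, epochs 0 to 9).
LOSSES = [
    0.156077, 0.013558, 0.007013, 0.029891, 0.008320, 0.003487,
    0.017474, 0.006232, 0.002457, 0.013332, 0.004897, 0.001900,
    0.010947, 0.003905, 0.001594, 0.009127, 0.003195, 0.001341,
    0.007751, 0.002681, 0.001157, 0.006605, 0.002218, 0.000972,
    0.005630, 0.001867, 0.000832, 0.004839, 0.001671, 0.000748,
]
VAL_LOSSES = [
    0.048945, 0.006090, 0.002332, 0.024670, 0.006243, 0.002337,
    0.022344, 0.006609, 0.002626, 0.022037, 0.007156, 0.003080,
    0.022025, 0.008209, 0.003835, 0.022751, 0.009226, 0.004209,
    0.024093, 0.009950, 0.004783, 0.025410, 0.011130, 0.005279,
    0.028299, 0.012204, 0.005969, 0.028623, 0.013037, 0.006519,
]


def make_df(values, rows_per_epoch=3):
    """Hypothetical helper: pair each value with its epoch (0.0, 0.0, 0.0, 1.0, ...)."""
    n_epochs = len(values) // rows_per_epoch
    epochs = np.repeat(np.arange(n_epochs, dtype=float), rows_per_epoch)
    return pd.DataFrame({"epoch": epochs, "loss": values})


losses_df = make_df(LOSSES)
val_losses_df = make_df(VAL_LOSSES)
```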
Upvotes: 0
Views: 2119
Reputation: 80329
You could combine the two dataframes into one, with a column recording which dataframe each row came from, and map that column to hue:
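import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# losses_df = pd.read_csv(...)
# val_losses_df = pd.read_csv(...)
# Tag each row with the dataframe it came from; this column drives hue.
losses_df['type'] = 'losses'
val_losses_df['type'] = 'val_losses'
# ignore_index=True avoids duplicate index labels in the combined frame
combined_df = pd.concat([losses_df, val_losses_df], ignore_index=True)
# On seaborn >= 0.12, errorbar="sd" replaces the deprecated ci="sd".
sns.relplot(x="epoch", y="loss", kind="line", hue="type", palette=['orange', 'red'], ci="sd", data=combined_df)
plt.tight_layout()
plt.show()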
Upvotes: 2
Reputation: 40697
relplot() is a figure-level function that creates a new figure at each call. Use the axes-level lineplot() instead and draw both lines on the same Axes:
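import matplotlib.pyplot as plt
import seaborn as sns
fig, ax = plt.subplots()
sns.lineplot(x="epoch", y="loss", color='orange', ci="sd", label='loss', data=losses_df, ax=ax)
sns.lineplot(x="epoch", y="loss", color='red', ci="sd", label='val_loss', data=val_losses_df, ax=ax)
ax.legend()  # label= gives each line its own legend entry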
Upvotes: 1