Reputation: 1353
I am training a neural network with different hyper-parameters and would like to plot the different results in order to compare which ones perform better.
I currently have a plugin that does this, but I would like to do it myself with matplotlib. I would like to replicate the following image.
That is, I want to plot "Activation" vs. "Accuracy".
I wrote the following code:
ss = df.groupby("Activation")["Accuracy"]
ss.plot(x='Activation', y='Accuracy', style='.')
I thought I should use groupby() for this, but I can't get anything to work, and it's not clear to me whether this is the best way to do it. How do I group all the data of one column (Accuracy) with respect to another column (Activation)? Is it better to plot it with Pandas or Matplotlib?
A sample of my dataframe:
Sensor,Architecture,Batch Size,Epochs,Activation,Optimizer,Loss,Accuracy,No No,Fr No,Si No,No Fr,Fr Fr,Si Fr,No Si,Fr Si,Si Si
Sens1,LeNet,16,100,tanh,Adam,0.9682227969169616,0.4455208480358124,2469,1486,45,1794,1918,288,540,2515,945
Sens2,LeNet,32,100,tanh,Adam,0.9768306612968444,0.441937506198883,895,572,2533,142,397,3461,0,0,4000
Sens3,LeNet,32,100,tanh,Adam,1.0033334493637085,0.4466041624546051,972,1981,1047,435,2527,1038,0,2182,1818
Sens1,LeNet,32,100,tanh,Adam,1.0002760887145996,0.4468958377838135,1048,248,2704,446,300,3254,0,0,4000
Sens3,LeNet,16,1,relu,Adam,0.991603136062622,0.4590624868869781,1042,168,2790,379,437,3184,0,43,3957
Sens1,LeNet,16,1,relu,Adam,0.9216567277908324,0.5240625143051147,1548,848,1604,710,968,2322,1,299,3700
Sens2,LeNet,16,1,relu,Adam,0.9375953674316406,0.5098333358764648,1780,361,1859,639,727,2634,24,306,3670
Sens3,LeNet,32,1,relu,Adam,0.9913602471351624,0.481187492609024,1007,336,2657,403,845,2752,0,169,3831
I tried other things like this, but it's obviously wrong...
activation = df['Activation'].values
accuracy = df['Accuracy'].values
eje_x = ('tanh','relu')
tanh, relu = [], []
for act, acc in zip(activation, accuracy):
    if act == 'tanh':
        tanh.append(acc)
    else:
        relu.append(acc)
plt.plot(tanh, label='tanh', linestyle='dashed', marker='.')
plt.plot(relu, label='relu', linestyle='dashed', marker='.')
plt.legend(loc='best')
plt.show()
Thanks.
Upvotes: 1
Views: 181
Reputation: 41327
Use seaborn.stripplot with jitter disabled:
import seaborn as sns
sns.stripplot(data=df, x='Activation', y='Accuracy', jitter=False)
Or with pandas, using DataFrame.plot.scatter:
df.plot.scatter(x='Activation', y='Accuracy')
Or with matplotlib directly:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.scatter(df.Activation, df.Accuracy)
ax.set_xlabel('Activation')
ax.set_ylabel('Accuracy')
Or with a data source:
plt.scatter(data=df, x='Activation', y='Accuracy')
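If you do want to go through groupby() as in the question, here is a minimal sketch of one way it could look. It assumes only the Activation and Accuracy columns from the sample data are needed and rebuilds them inline so the snippet runs on its own; the variable names (grouped, summary) are just illustrative. Iterating over the grouped column gives one Series of accuracies per activation, which can be drawn as dots over a single categorical x position.

```python
import io

import matplotlib.pyplot as plt
import pandas as pd

# Two columns taken from the question's sample data (assumption: the other
# columns are not needed for this plot).
csv = """Activation,Accuracy
tanh,0.4455208480358124
tanh,0.441937506198883
tanh,0.4466041624546051
tanh,0.4468958377838135
relu,0.4590624868869781
relu,0.5240625143051147
relu,0.5098333358764648
relu,0.481187492609024
"""
df = pd.read_csv(io.StringIO(csv))

# sort=False keeps the groups in order of first appearance (tanh, then relu).
grouped = df.groupby("Activation", sort=False)["Accuracy"]

fig, ax = plt.subplots()
# Iterating a grouped column yields (group name, Series of values) pairs.
for name, values in grouped:
    # Repeat the group name so every accuracy lands on the same x position.
    ax.plot([name] * len(values), values, ".", label=name)
ax.set_xlabel("Activation")
ax.set_ylabel("Accuracy")
ax.legend(loc="best")
# Call plt.show() here to display the figure interactively.

# Per-group summary for comparing the settings numerically.
summary = grouped.agg(["mean", "min", "max"])
print(summary)
```

The picture should be essentially the same as with the scatter calls above; the groupby route is mostly worth it when you also want per-group numbers such as the mean or the best accuracy.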
Upvotes: 1