Faizy

Reputation: 140

Why do callbacks in Keras give KeyError: 'metrics'?

My callbacks raise KeyError: 'metrics' when I start training in Colab.

DATASET: SETI

!pip install livelossplot
from livelossplot.tf_keras import PlotLossesCallback
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint

from sklearn.metrics import confusion_matrix
from sklearn import metrics

np.random.seed(42)
import warnings
warnings.simplefilter('ignore')
%matplotlib inline
print('Tensorflow version:', tf.__version__)

. . . .

model.compile(optimizer = optimizer, loss = 'categorical_crossentropy', metrics = ['accuracy'])
model.summary()

checkpoint = ModelCheckpoint("model_weights.h5", monitor='val_loss',
                             save_weights_only=True, mode='min', verbose=0)

my_callbacks = [PlotLossesCallback(), checkpoint]  # , reduce_lr]

batch_size = 32
history = model.fit(
    datagen_train.flow(x_train, y_train, batch_size=batch_size, shuffle=True),
    steps_per_epoch=len(x_train)//batch_size,
    validation_data = datagen_val.flow(x_val, y_val, batch_size=batch_size, shuffle=True),
    validation_steps = len(x_val)//batch_size,
    epochs=50,
    callbacks=my_callbacks
)

Error

KeyError                                  Traceback (most recent call last)
<ipython-input-60-ff0dc86d079d> in <module>()
     11     validation_steps = len(x_val)//batch_size,
     12     epochs=12,
---> 13     callbacks=callbacks
     14 )

3 frames
/usr/local/lib/python3.6/dist-packages/livelossplot/generic_keras.py in on_train_begin(self, logs)
     29 
     30     def on_train_begin(self, logs={}):
---> 31         self.liveplot.set_metrics([metric for metric in self.params['metrics'] if not metric.startswith('val_')])
     32 
     33         # slightly convolved due to model.complie(loss=...) stuff

KeyError: 'metrics'

Upvotes: 0

Views: 1699

Answers (1)

event_horizon

Reputation: 46

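Your import uses the older livelossplot API, which changed in newer versions. The old callback in generic_keras.py reads self.params['metrics'] in on_train_begin, but newer TensorFlow versions (2.2 and later) no longer put a 'metrics' key in the callback parameters, hence the KeyError.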
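The failing lookup from the traceback can be reproduced without TensorFlow. This is a minimal sketch: the two params dictionaries are hypothetical and only illustrate the assumption that older Keras included a 'metrics' list while newer tf.keras passes keys such as 'epochs', 'steps' and 'verbose' only. The function training_metric_names is a hypothetical stand-in that copies the list comprehension on line 31 of generic_keras.py.

```python
# Hypothetical callback params, for illustration only (not captured from a real run).
OLD_STYLE_PARAMS = {
    "epochs": 50,
    "steps": 100,
    "verbose": 1,
    "metrics": ["loss", "accuracy", "val_loss", "val_accuracy"],
}
NEW_STYLE_PARAMS = {"epochs": 50, "steps": 100, "verbose": 1}  # no 'metrics' key


def training_metric_names(params):
    """Same lookup as the old livelossplot on_train_begin."""
    return [m for m in params["metrics"] if not m.startswith("val_")]


print(training_metric_names(OLD_STYLE_PARAMS))  # ['loss', 'accuracy']

try:
    training_metric_names(NEW_STYLE_PARAMS)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'metrics'
```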

Change your import statement from

from livelossplot.tf_keras import PlotLossesCallback

to

from livelossplot.inputs.tf_keras import PlotLossesCallback
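If you write your own callback, one way to avoid the problem is to take metric names from the logs dict that Keras passes to on_epoch_end instead of from self.params. The sketch below is a generic pattern, not livelossplot's actual implementation: the class LogsBasedMetricTracker and the fake epoch logs are hypothetical, and the class only mirrors the Keras hook names (set_params, on_train_begin, on_epoch_end) so it runs without TensorFlow.

```python
class LogsBasedMetricTracker:
    """Hypothetical callback-like class using the Keras hook names.

    In a real notebook it would subclass tf.keras.callbacks.Callback;
    here it stays framework-free so the pattern can run on its own.
    """

    def __init__(self):
        self.params = {}
        self.history = {}

    def set_params(self, params):
        self.params = params

    def on_train_begin(self, logs=None):
        # Deliberately does NOT read self.params['metrics'], which newer Keras omits.
        self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        # Metric names are discovered from the per-epoch logs dict.
        for name, value in (logs or {}).items():
            self.history.setdefault(name, []).append(value)

    def training_metrics(self):
        return [name for name in self.history if not name.startswith("val_")]


# Hypothetical driver standing in for model.fit(); the log values are made up.
tracker = LogsBasedMetricTracker()
tracker.set_params({"epochs": 2, "steps": 10, "verbose": 1})
tracker.on_train_begin()
fake_logs = [
    {"loss": 0.9, "accuracy": 0.6, "val_loss": 1.0, "val_accuracy": 0.55},
    {"loss": 0.7, "accuracy": 0.7, "val_loss": 0.8, "val_accuracy": 0.65},
]
for epoch, logs in enumerate(fake_logs):
    tracker.on_epoch_end(epoch, logs)

print(tracker.training_metrics())  # ['loss', 'accuracy']
```

Because the names come from logs, the tracker works the same whether or not the params dict contains 'metrics'.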

Check out the livelossplot GitHub repository for more information and examples: livelossplot-github

Upvotes: 2
