gibbidi

Reputation: 183

How to record val_loss and loss per batch in keras

I'm using a callback in Keras to record loss and val_loss per epoch, but I would like to do the same per batch. I found a callback method called on_batch_begin(self, batch, logs={}), but I'm not sure how to use it.

Upvotes: 18

Views: 9028

Answers (4)

Yusuf Rıdvan SAVUT

Reputation: 51

I have an answer: I computed the validation loss manually with the loss function (binary cross-entropy) at the end of each batch. I did this because calling model.evaluate(x_test, y_test) for every batch made the training phase take quite a long time.

Calling model.predict at the end of each batch and applying BCE to the predictions was much faster for me.

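# Assumed imports; x_test_scaled and y_test are the held-out data defined elsewhere
import tensorflow as tf
from tensorflow import keras

class LossHistory(keras.callbacks.Callback):

    def on_train_begin(self, logs={}):
        self.losses = []
        self.val_losses = []

    def on_batch_end(self, batch, logs={}):
        # Predict on the held-out set and compute binary cross-entropy by hand
        y_pred = self.model.predict(x_test_scaled)
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
        val_loss = bce(y_test, y_pred).numpy()
        self.val_losses.append(val_loss)
        # The training loss for this batch comes straight from the logs
        self.losses.append(logs.get('loss'))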

All that remains is to compile and fit the model.

hist = LossHistory()

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.binary_crossentropy,
              metrics=[keras.metrics.binary_crossentropy])

history = model.fit(x_train_scaled, y_train,batch_size=256, epochs=15,
          verbose=1,callbacks=[hist], validation_data=(x_test_scaled, y_test))

Once training finishes, the validation loss for every batch step is available in:

hist.val_losses
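For reference, the quantity computed by hand in on_batch_end above is the mean of -[y*log(p) + (1-y)*log(1-p)] over the samples. Here is a minimal NumPy sketch of that formula, not Keras's own implementation; the function name and the clipping value eps are assumptions made for illustration.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy for predicted probabilities (not logits)."""
    y_true = np.asarray(y_true, dtype=float)
    # Clip predictions so log(0) never occurs; eps is an assumed value.
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1.0 - eps)
    per_sample = -(y_true * np.log(y_pred)
                   + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(per_sample.mean())

# Made-up labels and predictions, only to exercise the function
y_true = np.array([1.0, 0.0, 1.0, 0.0])
y_pred = np.array([0.9, 0.1, 0.8, 0.3])
manual_loss = binary_crossentropy(y_true, y_pred)
print(manual_loss)
```

Because this only needs the predictions, the cost per batch is one forward pass over the validation data plus a cheap reduction.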

Upvotes: 0

Marcin Możejko

Reputation: 40516

Here is an example of a custom callback, adapted from an example here:

import keras
from keras.models import Sequential
from keras.layers import Dense, Activation

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []
        self.val_losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
        # Batch-level logs normally carry no 'val_loss', so this may append None
        self.val_losses.append(logs.get('val_loss'))

model = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

# X_train and Y_train are assumed to be defined elsewhere
history = LossHistory()
model.fit(X_train, Y_train, batch_size=128, epochs=20, verbose=0, validation_split=0.1,
          callbacks=[history])

print(history.losses)
# outputs
'''
[0.66047596406559383, 0.3547245744908703, ..., 0.25953155204159617, 0.25901699725311789]
'''
print(history.val_losses)

Upvotes: 16

pain.reign

Reputation: 371

According to the tf.keras documentation, the logs passed to on_batch_end do not include val_loss:

on_batch_end: logs include loss, and optionally acc (if accuracy monitoring is enabled).

as mentioned here: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback
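One way to check this on your own installation is to record which keys actually show up in the logs at batch end and at epoch end. The sketch below assumes TensorFlow's bundled Keras is installed; the class name LogKeyRecorder, the tiny model and the random data are all made up for illustration.

```python
import numpy as np
from tensorflow import keras

class LogKeyRecorder(keras.callbacks.Callback):
    """Collects the names of the entries Keras passes in `logs`."""

    def on_train_begin(self, logs=None):
        self.batch_keys = set()
        self.epoch_keys = set()

    def on_batch_end(self, batch, logs=None):
        # Keys available after each training batch
        self.batch_keys.update((logs or {}).keys())

    def on_epoch_end(self, epoch, logs=None):
        # Keys available after each epoch (validation has run by then)
        self.epoch_keys.update((logs or {}).keys())

# Tiny random dataset, only to make fit() run
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = rng.integers(0, 2, size=(64, 1)).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="rmsprop", loss="binary_crossentropy")

recorder = LogKeyRecorder()
model.fit(x, y, epochs=2, batch_size=16, validation_split=0.25,
          verbose=0, callbacks=[recorder])

print("batch-level keys:", sorted(recorder.batch_keys))
print("epoch-level keys:", sorted(recorder.epoch_keys))
```

If val_loss appears only in the epoch-level keys, a callback that reads logs.get('val_loss') in on_batch_end will just collect None values.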

Upvotes: 3

Gerardo Muñoz

Reputation: 41

import numpy as np
import matplotlib.pyplot as plt
import keras

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.history = {'loss':[],'val_loss':[]}

    def on_batch_end(self, batch, logs={}):
        self.history['loss'].append(logs.get('loss'))

    def on_epoch_end(self, epoch, logs={}):
        self.history['val_loss'].append(logs.get('val_loss'))

history = LossHistory()

model = keras.Sequential()
model.add(keras.layers.Dense(32, activation='relu', input_dim=100))
model.add(keras.layers.Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy')

# Generate dummy data
data = np.random.random((1000, 100))
labels = np.random.randint(2, size=(1000, 1))

# Train the model, iterating on the data in batches of 32 samples
model.fit(data, labels, epochs=10, batch_size=32, 
          validation_split=0.2, callbacks=[history])

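# Plot the history: loss has one point per batch, val_loss one point per epoch
y1 = history.history['loss']
y2 = history.history['val_loss']
x1 = np.arange(len(y1))
k = len(y1) // len(y2)  # batches per epoch
x2 = np.arange(k, len(y1) + 1, k)  # place each val_loss at the end of its epoch
fig, ax = plt.subplots()
line1, = ax.plot(x1, y1, label='loss')
line2, = ax.plot(x2, y2, label='val_loss')
ax.legend()  # show the labels set above
plt.show()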
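The plotting code above puts each per-epoch val_loss at the batch count where its epoch ended. That alignment can be sketched as a small helper, assuming every epoch contributes the same number of batch-level points. The name epoch_end_positions is hypothetical, and it returns 0-based batch indices, so its values are one less than x2 above.

```python
import numpy as np

def epoch_end_positions(n_batch_points, n_epoch_points):
    """Return the 0-based batch index at which each epoch ended."""
    if n_epoch_points <= 0 or n_batch_points % n_epoch_points != 0:
        raise ValueError("expected the same number of batches in every epoch")
    batches_per_epoch = n_batch_points // n_epoch_points
    # Epoch e (0-based) ends at batch index (e + 1) * batches_per_epoch - 1
    return np.arange(1, n_epoch_points + 1) * batches_per_epoch - 1

# 800 training samples with batch_size=32 give 25 batches per epoch;
# 10 epochs therefore yield 250 batch-level losses and 10 val_loss values.
positions = epoch_end_positions(250, 10)
print(positions)
```

The result can be passed as the x values for the val_loss curve in place of x2.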

Upvotes: 4
