Antoine Trouve

Reputation: 1273

How to access the number of steps per epoch in a Keras Lambda Callback

I would like to know how to access the number of batches per epoch from inside a Keras lambda callback, that is, the value passed to the steps_per_epoch parameter of the model.fit function.

Below is my custom callback:

(I want to fill in the ??????? in batch_per_epoch = ???????.)

import json
from datetime import datetime

import keras

class MyBatchLogger(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self._current_epoch = 0

    def on_epoch_begin(self, epoch, logs=None):
        self._current_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        print("Epoch end", logs)

    def on_batch_end(self, batch, logs=None):
        batch_per_epoch = ???????
        acc = logs["acc"].item()
        loss = logs["loss"].item()
        mae = logs["mean_absolute_error"].item()
        ca = logs["categorical_accuracy"].item()

        print(json.dumps({
            "timestamp": datetime.now().isoformat(),
            "epoch": self._current_epoch,
            "batch": batch,
            "batchPerEpoch": batch_per_epoch,
            "accuracy": acc,
            "meanAbsoluteError": mae,
            "categoricalAccuracy": ca,
            "loss": loss
        }))

I'm using Keras 2.2.5 with TensorFlow 1.14.1, but I'm open to upgrading if necessary.

Upvotes: 3

Views: 1167

Answers (1)

Guido Fantini

Reputation: 56

This answer might come a bit late, but I spent some time digging for it, so maybe it's helpful anyway.

The information you need is in the callback's params dictionary, which Keras fills in before training starts:

self.params.get('steps')
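To make that concrete, here is a minimal sketch of how the lookup could be wired into the logger from the question. The helper name format_batch_record and its metric_names parameter are hypothetical, not part of Keras. It assumes params is the callback's self.params dict and that the 'steps' entry may be missing or None (for example, when steps_per_epoch was not passed), in which case batchPerEpoch comes out as null.

```python
import json
from datetime import datetime


def format_batch_record(params, epoch, batch, logs=None, metric_names=("loss",)):
    """Return a JSON string describing one finished batch.

    `params` is assumed to be the callback's `self.params` dict.
    `metric_names` is a hypothetical parameter listing which entries of
    `logs` to copy into the record.
    """
    logs = logs or {}
    record = {
        "timestamp": datetime.now().isoformat(),
        "epoch": epoch,
        "batch": batch,
        # None when 'steps' is absent from params or was never set.
        "batchPerEpoch": (params or {}).get("steps"),
    }
    # Copy only the requested metrics, converting NumPy scalars to floats
    # so that json.dumps can serialize them.
    for name in metric_names:
        if name in logs:
            record[name] = float(logs[name])
    return json.dumps(record)


# Possible use inside the callback from the question:
#     def on_batch_end(self, batch, logs=None):
#         print(format_batch_record(self.params, self._current_epoch, batch, logs))
```

Keeping the formatting in a plain function means it can be checked with an ordinary dict, without running a training loop. Which keys params actually contains can differ between Keras versions, so it is worth printing self.params once in on_train_begin to confirm.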

Upvotes: 4
