Reputation: 1273
I would like to know how to access the number of batches per epoch from inside a custom Keras callback, that is, the value passed to the steps_per_epoch parameter of the model.fit function.
Below is my custom callback (I want to fill in the ??????? in batch_per_epoch = ???????):
class MyBatchLogger(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self._current_epoch = 0

    def on_epoch_begin(self, epoch, logs=None):
        self._current_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        print("Epoch end", logs)

    def on_batch_end(self, batch, logs={}):
        batch_per_epoch = ???????
        acc = logs["acc"].item()
        loss = logs["loss"].item()
        mae = logs["mean_absolute_error"].item()
        ca = logs["categorical_accuracy"].item()
        print(json.dumps({
            "timestamp": datetime.now().isoformat(),
            "epoch": self._current_epoch,
            "batch": batch,
            "batchPerEpoch": batch_per_epoch,
            "accuracy": acc,
            "meanAbsoluteError": mae,
            "categoricalAccuracy": ca,
            "loss": loss
        }))
I'm using Keras 2.2.5 with TensorFlow 1.14.1, but I'm happy to upgrade if necessary.
Upvotes: 3
Views: 1167
Reputation: 56
This answer might come a bit late, but I spent some time digging for it, so maybe it's helpful anyway.
The information you need is in the callback's params dictionary, which Keras fills in before training starts:
self.params.get('steps')
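To make that a bit more concrete, here is a minimal sketch of a helper you could call from on_batch_end. The function name batches_per_epoch is my own, not part of Keras. It also assumes that when steps_per_epoch is not passed, 'steps' can be None and that older standalone Keras versions put 'samples' and 'batch_size' in the same dictionary; as far as I can tell newer versions only provide 'steps', so treat the fallback as an assumption and check what your version actually puts in self.params.

```python
import math


def batches_per_epoch(params):
    """Return the number of batches per epoch from a callback's params dict.

    `params` is expected to be the dictionary Keras stores on the callback
    as `self.params`. Returns None if the value cannot be determined.
    """
    # 'steps' holds the steps_per_epoch value when Keras knows it.
    steps = params.get("steps")
    if steps is not None:
        return steps

    # Assumed fallback for older Keras versions: derive the count from the
    # number of samples and the batch size, if both keys are present.
    samples = params.get("samples")
    batch_size = params.get("batch_size")
    if samples and batch_size:
        return math.ceil(samples / batch_size)

    return None
```

Inside the callback from the question, the placeholder line would then become batch_per_epoch = batches_per_epoch(self.params).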
Upvotes: 4