Reputation: 11
We can create a tf.keras.callbacks.ModelCheckpoint() and pass it through the callbacks argument of the fit() method to save the best model checkpoint, but how do we achieve the same thing in a custom training loop?
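For context, this is roughly what the built-in route looks like; a minimal sketch, where model, checkpoint_path and the train/validation arrays are placeholders rather than code from any specific project:

    import tensorflow as tf

    # fit() drives the callback for us: ModelCheckpoint compares 'val_loss'
    # at the end of every epoch and keeps the best weights.
    cp = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                            monitor='val_loss',
                                            save_best_only=True,
                                            save_weights_only=True)
    model.fit(x_train, y_train,
              validation_data=(x_val, y_val),
              epochs=10,
              callbacks=[cp])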
Upvotes: 1
Views: 1559
Reputation: 25
You can use tf.keras.callbacks.CallbackList to fire the same hooks that fit() would fire, so ModelCheckpoint works inside a custom training loop:
import time

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import CallbackList, ModelCheckpoint
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam

# checkpoint_path, init_lr, args, model, train_gen and val_gen are assumed
# to be defined elsewhere in your own setup.
cp_callback = ModelCheckpoint(filepath=checkpoint_path,
                              monitor='val_loss',
                              save_weights_only=True,
                              save_best_only=True,
                              mode='auto',
                              save_freq='epoch',
                              verbose=1)

# Wrap the callbacks in a CallbackList and attach the model, just as fit() does.
callbacks = CallbackList([cp_callback], add_history=True, model=model)
logs = {}
callbacks.on_train_begin(logs=logs)

optimizer = Adam(learning_rate=init_lr, beta_1=0.9, beta_2=0.999, clipvalue=1.0)
loss_fn = BinaryCrossentropy(from_logits=False)

train_loss_tracker = tf.keras.metrics.Mean()
val_loss_tracker = tf.keras.metrics.Mean()
train_acc_metric = tf.keras.metrics.BinaryAccuracy()
val_acc_metric = tf.keras.metrics.BinaryAccuracy()


@tf.function(experimental_relax_shapes=True)
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_loss_tracker.update_state(loss_value)
    train_acc_metric.update_state(y, logits)
    return {"train_loss": train_loss_tracker.result(),
            "train_accuracy": train_acc_metric.result()}


@tf.function(experimental_relax_shapes=True)
def val_step(x, y):
    val_logits = model(x, training=False)
    val_loss = loss_fn(y, val_logits)
    val_loss_tracker.update_state(val_loss)
    val_acc_metric.update_state(y, val_logits)
    return {"val_loss": val_loss_tracker.result(),
            "val_accuracy": val_acc_metric.result()}


for epoch in range(args.max_epoch):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()
    callbacks.on_epoch_begin(epoch, logs=logs)

    for step, (x_batch_train, y_batch_train) in enumerate(train_gen):
        callbacks.on_batch_begin(step, logs=logs)
        callbacks.on_train_batch_begin(step, logs=logs)

        train_dict = train_step(x_batch_train, np.expand_dims(y_batch_train, axis=0))
        logs["train_loss"] = train_dict["train_loss"]

        callbacks.on_train_batch_end(step, logs=logs)
        callbacks.on_batch_end(step, logs=logs)

        if step % 100 == 0:
            print("Training loss (for one batch) at step %d: %.4f"
                  % (step, float(logs["train_loss"])))

    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))
    train_acc_metric.reset_states()
    train_loss_tracker.reset_states()

    for step, (x_batch_val, y_batch_val) in enumerate(val_gen):
        callbacks.on_batch_begin(step, logs=logs)
        callbacks.on_test_batch_begin(step, logs=logs)
        val_step(x_batch_val, np.expand_dims(y_batch_val, axis=0))
        callbacks.on_test_batch_end(step, logs=logs)
        callbacks.on_batch_end(step, logs=logs)

    # ModelCheckpoint monitors 'val_loss', so it must be in logs before on_epoch_end.
    logs["val_loss"] = val_loss_tracker.result()
    val_acc = val_acc_metric.result()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))
    val_acc_metric.reset_states()
    val_loss_tracker.reset_states()

    # on_epoch_end is where ModelCheckpoint decides whether to save the best weights.
    callbacks.on_epoch_end(epoch, logs=logs)

callbacks.on_train_end(logs=logs)
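Once training is done, the best weights can be restored from the same checkpoint path; a minimal sketch, assuming save_weights_only=True and the checkpoint_path used above:

    # Reload the best weights that ModelCheckpoint saved during the loop.
    model.load_weights(checkpoint_path)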
Upvotes: 1