Pandas

Reputation: 61

ModelCheckpoint monitoring values when the model has multiple outputs

My model has two outputs, and I want to monitor one of them to decide when to save the model. Below is part of my code. I am using TensorFlow 2.0.

model = MobileNetBaseModel()()
model.compile(optimizer=tf.keras.optimizers.Adam(),
              metrics={"pitch_yaw_roll": "mae"},
              loss={"pitch_yaw_roll": compute_mse_loss, # or "mse"
                    "total_logits": compute_cross_entropy_loss(num_classes=num_classes)},
              loss_weights= {"pitch_yaw_roll":mse_weight, "total_logits":cross_entropy_weight})
file_path = os.path.join(checkpoint_path, "model.{epoch:02d}-{val_loss:.2f}.h5")
tf.keras.callbacks.ModelCheckpoint(filepath=file_path,
                                   monitor="val_loss",
                                   verbose=1,
                                   save_freq=save_freq,
                                   save_best_only=True)

The ModelCheckpoint callback monitors 'val_loss' by default. How do I make it monitor a different value? I want to monitor the "mae" metric of the "pitch_yaw_roll" output, i.e. {"pitch_yaw_roll": "mae"}.

Upvotes: 3

Views: 1936

Answers (2)

euh

Reputation: 451

Just adding to the comment above: I believe your checkpoint doesn't work because the name of the value to monitor is incorrect. A general solution is to take a peek at the history that fit returns.

import pandas as pd

history = model.fit(...)
pd.DataFrame(history.history)  # one column per logged loss/metric

There you will find the names of the metrics you can use in the monitor argument.
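Since history.history is a plain dict that maps each logged name to a list of per-epoch values, the candidates for one output can also be picked out without pandas. The sketch below is only an illustration: monitor_candidates is a hypothetical helper, and the dict stands in for a real history with made-up keys and numbers, so check the keys your own fit call produces.

```python
def monitor_candidates(history, output_name):
    """Return the logged names that mention one output, validation names first."""
    names = [key for key in history if output_name in key]
    # Sort so that "val_" (validation) entries come before training entries.
    return sorted(names, key=lambda key: (not key.startswith("val_"), key))


# Hypothetical stand-in for history.history; keys and values are
# illustrative assumptions, not real training output.
example_history = {
    "loss": [1.90, 1.42],
    "pitch_yaw_roll_loss": [0.80, 0.61],
    "pitch_yaw_roll_mae": [0.70, 0.55],
    "total_logits_loss": [1.10, 0.81],
    "val_loss": [1.75, 1.50],
    "val_pitch_yaw_roll_loss": [0.77, 0.66],
    "val_pitch_yaw_roll_mae": [0.68, 0.60],
    "val_total_logits_loss": [0.98, 0.84],
}

candidates = monitor_candidates(example_history, "pitch_yaw_roll")
print(candidates)
```

Any of the printed names that actually appears in your own history can be passed as monitor.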

Upvotes: 0

bluesummers

Reputation: 12607

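If you want ModelCheckpoint to save according to a different value, pass the name under which Keras logs that value as monitor. For a model with several named outputs, the logged name is normally the output name combined with the metric name, so the "mae" metric of "pitch_yaw_roll" should appear as "pitch_yaw_roll_mae" for training and "val_pitch_yaw_roll_mae" for validation. The exact names can differ between versions, so check the keys of history.history first.

So, for example, to save only the epoch with the best validation MAE for "pitch_yaw_roll" (best being the minimum value), use

tf.keras.callbacks.ModelCheckpoint(filepath=file_path,
                                   monitor="val_pitch_yaw_roll_mae",
                                   verbose=1,
                                   mode="min",
                                   save_freq=save_freq,
                                   save_best_only=True)

If you use "pitch_yaw_roll_mae" instead of "val_pitch_yaw_roll_mae", it will save according to the training metric rather than the validation metric.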
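The same logged names matter for the file name. As far as I know, ModelCheckpoint fills the placeholders in filepath from the epoch number and the logged values, so a placeholder naming a value that is not logged cannot be filled. The sketch below only imitates that step with plain str.format and does not call Keras. checkpoint_name is a hypothetical helper, and the logs dict uses made-up values and assumed key names.

```python
def checkpoint_name(template, epoch, logs):
    """Fill a checkpoint filename template from the epoch number and logged values."""
    return template.format(epoch=epoch, **logs)


# Hypothetical end-of-epoch logs; the key names are assumptions about what
# Keras would log for this model, and the numbers are invented.
logs = {"val_loss": 1.5, "val_pitch_yaw_roll_mae": 0.6049}

template = "model.{epoch:02d}-{val_pitch_yaw_roll_mae:.2f}.h5"
name = checkpoint_name(template, 3, logs)
print(name)  # model.03-0.60.h5

# A placeholder that names a value missing from the logs raises KeyError.
try:
    checkpoint_name("model.{epoch:02d}-{val_mae:.2f}.h5", 3, logs)
    missing = None
except KeyError as err:
    missing = err.args[0]
print(missing)  # val_mae
```

Keeping the monitored name and the placeholder in filepath identical makes it easy to see which value each saved file was selected on.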

Upvotes: 2
