gab

Reputation: 307

How can I output some data during a model.fit() run in TensorFlow?

I would like to print the value and/or the shape of a tensor during a model.fit() run and not before. In PyTorch I can just put a print(input.shape) statement into the model.forward() function.

Is there something similar in TensorFlow?

Upvotes: 1

Views: 811

Answers (1)

Reactgular

Reputation: 54791

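You can pass one or more callback objects to model.fit() through its callbacks argument. Keras then calls the callback's hook methods (on_train_batch_begin, on_train_batch_end, on_epoch_end and so on) at the corresponding stages of training, so anything you print there appears while fit() is running. Keep in mind that these hooks receive the batch index and a logs dictionary of metrics such as the loss, not the intermediate tensors inside the model.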

https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback

import datetime

import tensorflow as tf


class MyCustomCallback(tf.keras.callbacks.Callback):

  # Called by model.fit() before and after every training batch.
  def on_train_batch_begin(self, batch, logs=None):
    print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_train_batch_end(self, batch, logs=None):
    print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

  # Called during evaluation (e.g. model.evaluate() or validation inside fit()).
  def on_test_batch_begin(self, batch, logs=None):
    print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_end(self, batch, logs=None):
    print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))


# get_model(), x_train and y_train are placeholders for your own model and data.
model = get_model()
model.fit(x_train, y_train, callbacks=[MyCustomCallback()])
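As a rough sketch of what a callback can actually see, the variant below prints the loss that Keras passes in the logs dictionary after every training batch. The class name LossPrinter, the random toy data and the one-layer model are illustrative assumptions rather than part of the original answer, and the sketch assumes TensorFlow 2.x.

```python
import numpy as np
import tensorflow as tf


class LossPrinter(tf.keras.callbacks.Callback):
    """Hypothetical callback: prints the per-batch metrics Keras passes in `logs`."""

    def __init__(self):
        super().__init__()
        self.seen = []  # (batch, loss) pairs, kept so they can be inspected later

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get("loss")
        self.seen.append((batch, loss))
        print(f"batch {batch}: loss={loss}")


# Toy data and model (assumptions), just to have something to fit.
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

printer = LossPrinter()
model.fit(x, y, batch_size=16, epochs=1, verbose=0, callbacks=[printer])
```

With 64 samples and batch_size=16 this prints four lines, one per training batch.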

https://www.tensorflow.org/guide/keras/custom_callback
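Because a callback only sees metrics, it cannot show the shape of a tensor inside the model. A closer analogue of putting print(input.shape) in a PyTorch forward() is a different technique from the callback approach: calling tf.print inside a Keras layer's call() method. The identity layer PrintShape and the toy model below are hypothetical, and the sketch assumes TensorFlow 2.x.

```python
import numpy as np
import tensorflow as tf


class PrintShape(tf.keras.layers.Layer):
    """Hypothetical identity layer that prints the runtime shape of its input."""

    def call(self, inputs):
        # tf.print executes every time the graph runs (i.e. on every batch),
        # whereas a plain Python print() only fires while fit() traces the function.
        tf.print("input shape:", tf.shape(inputs))
        return inputs


model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    PrintShape(),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse")

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
history = model.fit(x, y, batch_size=16, epochs=1, verbose=0)
```

Each training batch should print something like "input shape: [16 4]" to stderr. Alternatively, compiling with run_eagerly=True lets an ordinary Python print(inputs.shape) inside call() run on every batch, at the cost of slower training.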

Upvotes: 2
