Bruce

Reputation: 31

Steps of tf.summary.* operations in TensorBoard are always 0

While training a model with TensorFlow 2.3, I want to visualize some intermediate tensors that are computed from the weights inside the computation graph of my custom tf.keras.layers.Layer.

So I use tf.summary.image() to record these tensors and visualize them as images like this:

class CustomizedLayer(tf.keras.layers.Layer):
    def call(self, inputs, training=None):
        # ... some code ...
        tf.summary.image(name="some_weight_map", data=some_weight_map)
        # ... some code ...

But in TensorBoard, no matter how many training steps have passed, only one image is shown, and it is at step 0.

I then tried setting the step parameter of tf.summary.image() to the value returned by tf.summary.experimental.get_step():

tf.summary.image(name="weight_map", data=weight_map, step=tf.summary.experimental.get_step())

I also update the step by calling tf.summary.experimental.set_step() from a custom Callback that uses a tf.Variable, as shown below:

class SummaryCallback(tf.keras.callbacks.Callback):
    def __init__(self, step_per_epoch):
        super().__init__()
        self.global_step = tf.Variable(initial_value=0, trainable=False, name="global_step")
        self.global_epoch = 0
        self.step_per_epoch = step_per_epoch
        tf.summary.experimental.set_step(self.global_step)

    def on_batch_end(self, batch, logs=None):
        self.global_step = batch + self.step_per_epoch * self.global_epoch
        tf.summary.experimental.set_step(self.global_step)
        # Whether or not the line above is commented out, calling
        # tf.summary.experimental.get_step() in computation graph code always returns 0.
        # tf.print(self.global_step)

    def on_epoch_end(self, epoch, logs=None):
        self.global_epoch += 1

An instance of this Callback is passed to model.fit() through the callbacks argument.

But the value returned by tf.summary.experimental.get_step() is still 0.

The TensorFlow documentation for tf.summary.experimental.set_step() says:

when using this with @tf.functions, the step value will be captured at the time the function is traced, so changes to the step outside the function will not be reflected inside the function unless using a tf.Variable step.

According to the documentation, I am already using a Variable to store the step, but its changes are still not reflected inside the function (or the keras.Model).

Note: In TensorFlow 1.x, my code produced the expected results with just a single tf.summary.image() line, before I migrated it to TensorFlow 2.

So I want to know whether my approach is wrong in TensorFlow 2.

In TF2, how can I get the training step inside the computation graph?

Or is there another way to summarize tensors (as scalars, images, etc.) inside a model in TensorFlow 2?

Upvotes: 1

Views: 1087

Answers (1)

Bruce

Reputation: 31

I found that this issue has been reported in the TensorFlow GitHub repository: https://github.com/tensorflow/tensorflow/issues/43568

It is caused by using tf.summary inside the model while the tf.keras.callbacks.TensorBoard callback is also enabled; in that case the step is always zero. The issue reporter gives a temporary workaround.

To fix it, subclass tf.keras.callbacks.TensorBoard and override the on_train_begin and on_test_begin methods like this:

class TensorBoardFix(tf.keras.callbacks.TensorBoard):
    """
    This fixes incorrect step values when using the TensorBoard callback with custom summary ops
    """

    def on_train_begin(self, *args, **kwargs):
        super(TensorBoardFix, self).on_train_begin(*args, **kwargs)
        tf.summary.experimental.set_step(self._train_step)

    def on_test_begin(self, *args, **kwargs):
        super(TensorBoardFix, self).on_test_begin(*args, **kwargs)
        tf.summary.experimental.set_step(self._val_step)

And use this fixed callback class in model.fit():

tensorboard_callback = TensorBoardFix(log_dir=log_dir, histogram_freq=1, write_graph=True, update_freq=1)
model.fit(dataset, epochs=200, callbacks=[tensorboard_callback])

This solved my problem, and now I can get the proper step inside my model by calling tf.summary.experimental.get_step().
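
One detail in the question's callback is worth noting separately: the line self.global_step = batch + ... rebinds the attribute to a plain Python int, so the tf.Variable created in __init__ is no longer the object being updated. The "tf.Variable step" behaviour quoted from the documentation relies on updating the same Variable in place with assign(). Below is a minimal sketch of that behaviour outside Keras, assuming TF 2.x with eager execution; the names global_step and read_step are only illustrative and are not part of the original code.

```python
import tensorflow as tf

# A tf.Variable step, registered once as the default summary step.
# (global_step and read_step are hypothetical names for this sketch.)
global_step = tf.Variable(0, dtype=tf.int64, trainable=False, name="global_step")
tf.summary.experimental.set_step(global_step)


@tf.function
def read_step():
    # get_step() runs while the function is traced and returns the Variable;
    # the traced graph then reads the Variable's current value on each call.
    return tf.identity(tf.summary.experimental.get_step())


first = int(read_step())       # function is traced here, Variable holds 0
global_step.assign(5)          # in-place update of the same Variable object
second = int(read_step())      # the already-traced function sees the update
global_step.assign_add(1)
third = int(read_step())

print(first, second, third)
```

In a callback, the equivalent change would be to call self.global_step.assign(...) instead of assigning a new value to the attribute.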

(The underlying TensorBoard callback issue may be fixed in later versions of TensorFlow.)
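
For the last part of the question (other ways to summarize tensors in TF2), one option that does not depend on the default step at all is to pass step explicitly to a summary writer you create yourself. The following is only a sketch under assumptions: the temporary log directory, the tag name x_mean and the train_step function are hypothetical, and it uses a hand-written tf.function rather than model.fit().

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical log directory for this sketch.
logdir = tempfile.mkdtemp()
writer = tf.summary.create_file_writer(logdir)

# Step counter kept as an int64 Variable so the traced function
# always reads its current value.
step = tf.Variable(0, dtype=tf.int64, trainable=False)


@tf.function
def train_step(x):
    # Write a summary with an explicit step, then advance the counter.
    with writer.as_default():
        tf.summary.scalar("x_mean", tf.reduce_mean(x), step=step)
    step.assign_add(1)


for _ in range(3):
    train_step(tf.random.uniform([4, 4]))

writer.flush()
final_step = int(step.numpy())
event_files = os.listdir(logdir)
print(final_step, event_files)
```

The same pattern should apply to tf.summary.image(); the trade-off is that the writer and the step counter have to be managed by hand instead of by the TensorBoard callback.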

Upvotes: 2
