arash javanmard

TensorFlow 2: displaying a histogram of weights

I am trying to display a histogram of all network weights (CNN) at each epoch in TensorBoard, using the LambdaCallback of TensorFlow 2 as follows:

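import numpy as np
import tensorflow as tf
from tensorflow import keras

def log_hist_weights(model, writer):
    def log_hist(epoch, logs):
        # get_weights() returns a plain list of NumPy arrays with no names attached
        Ws = model.get_weights()
        # The arrays have different shapes, so flatten them into one array for a single histogram
        all_weights = np.concatenate([w.ravel() for w in Ws])
        with writer.as_default():
            # In TF 2, step is required unless tf.summary.experimental.set_step() was called
            tf.summary.histogram("all_weights", all_weights, step=epoch)
    return log_hist

# baseline_model and file_writer are defined elsewhere
hist_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_hist_weights(baseline_model, file_writer))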

The problem is that get_weights() returns all the network weights without any names (filter weights, BatchNormalization weights, and so on), but I am only interested in the CNN filter weights.
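The difference can be seen on a small example. This is a sketch assuming TF 2.x with `tf.keras`; the two-layer model and its layer names (`conv_a`, `bn`) are hypothetical. `model.get_weights()` yields bare NumPy arrays, whereas each layer's `weights` attribute holds variables that stay attached to the layer that owns them and carry a name.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy CNN: one Conv2D followed by BatchNormalization
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(8, 3, name="conv_a"),
    tf.keras.layers.BatchNormalization(name="bn"),
])

# get_weights(): a flat list of NumPy arrays, nothing says which layer they belong to
arrays = model.get_weights()
for a in arrays:
    print(type(a).__name__, a.shape)

# layer.weights: variables grouped by the layer that owns them, each with a name
for layer in model.layers:
    for var in layer.weights:
        print(layer.name, var.name, tuple(var.shape))
```

The Conv2D layer contributes a kernel and a bias; BatchNormalization contributes gamma, beta, and the two moving statistics, all of which end up mixed together in `arrays`.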

It would be great if I could implement something like this one in TensorFlow 2.

How can I display a histogram of the filter weights using TensorFlow?

Answers (1)

arash javanmard

For anybody else with the same problem, here is how I finally solved it using TensorFlow 2:

import tensorflow as tf
from tensorflow import keras

def log_hist_weights(model, writer):
    def log_hist(epoch, logs):
        # One histogram per trainable variable, tagged with the variable's name
        with writer.as_default():
            for tf_var in model.trainable_weights:
                tf.summary.histogram(tf_var.name, tf_var.numpy(), step=epoch)
    return log_hist

# baseline_model and file_writer are defined elsewhere
hist_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_hist_weights(baseline_model, file_writer))
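Logging `trainable_weights` still includes biases and BatchNormalization gamma/beta. To keep only the convolution filters, as the question asks, one option is to filter by layer type. The following is a minimal sketch assuming TF 2.x with `tf.keras`; `log_conv_filter_hist`, the toy model, and the random data are hypothetical and only there to exercise the callback. Tags are built from the layer name rather than `var.name`, since variable naming differs between Keras versions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def log_conv_filter_hist(model, writer):
    """Build an on_epoch_end hook that logs one histogram per Conv2D kernel."""
    def log_hist(epoch, logs):
        with writer.as_default():
            for layer in model.layers:
                # Keep only convolution kernels; BatchNormalization, Dense and biases are skipped
                if isinstance(layer, tf.keras.layers.Conv2D) and hasattr(layer, "kernel"):
                    tf.summary.histogram(f"{layer.name}/kernel", layer.kernel.numpy(), step=epoch)
        writer.flush()
    return log_hist

# Hypothetical model, only used to exercise the callback
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(8, 3, activation="relu", name="conv_a"),
    tf.keras.layers.BatchNormalization(name="bn"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, name="head"),
])
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

# create_file_writer creates the directory if it does not exist
log_dir = os.path.join(tempfile.mkdtemp(), "weights_hist")
file_writer = tf.summary.create_file_writer(log_dir)
hist_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=log_conv_filter_hist(model, file_writer))

# Random data, just to run two epochs
x = np.random.rand(32, 28, 28, 1).astype("float32")
y = np.random.randint(0, 10, size=(32,))
model.fit(x, y, epochs=2, batch_size=16, callbacks=[hist_callback], verbose=0)
```

Pointing `tensorboard --logdir` at `log_dir` should then show a `conv_a/kernel` entry under the Histograms tab, with one slice per epoch. Subclassing `tf.keras.callbacks.Callback` would work as well and gives access to `self.model`, which avoids passing the model in by hand.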
