I am trying to display a histogram of all network weights (CNN) at each epoch in TensorBoard, using the LambdaCallback of TensorFlow 2, as follows:
import tensorflow as tf
from tensorflow import keras

def log_hist_weights(model, writer):
    def log_hist(epoch, logs):
        # get_weights() returns a plain list of NumPy arrays, one per variable
        Ws = model.get_weights()
        with writer.as_default():
            tf.summary.histogram("epoch: " + str(epoch), Ws)
    return log_hist

hist_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_hist_weights(baseline_model, file_writer))
The problem is that get_weights() returns all the network weights (filter weights, BatchNormalization weights, and so on) as unnamed arrays, whereas I am only interested in the CNN filter weights.
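A minimal sketch of how the unnamed arrays can be tied back to their layers: calling get_weights() on each layer rather than on the whole model groups the arrays by layer name. The small CNN and its layer names (conv1, bn1, dense) and the helper named_weights are hypothetical, chosen only for illustration.

```python
import tensorflow as tf
from tensorflow import keras


def build_model():
    """Small hypothetical CNN standing in for baseline_model."""
    return keras.Sequential([
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(8, 3, name="conv1"),
        keras.layers.BatchNormalization(name="bn1"),
        keras.layers.Flatten(name="flatten"),
        keras.layers.Dense(10, name="dense"),
    ])


def named_weights(model):
    """Hypothetical helper: {layer_name: [arrays]}, skipping layers without weights."""
    return {layer.name: layer.get_weights()
            for layer in model.layers if layer.get_weights()}


model = build_model()
for name, arrays in named_weights(model).items():
    # conv1 holds [kernel, bias]; bn1 holds gamma, beta, moving mean, moving variance
    print(name, [a.shape for a in arrays])
```

With the arrays grouped this way, the filter bank of a Conv2D layer is simply the first array in its list.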
It would be great if I could implement something like this one in TensorFlow 2.
How can I display a histogram of the filter weights using TensorFlow?
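For anybody else with the same problem, here is how I finally solved it in TensorFlow 2. Unlike get_weights(), model.trainable_weights returns the variables themselves, each with a name, so every variable gets its own labeled histogram, and step=epoch places it on the epoch axis in TensorBoard: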
import tensorflow as tf
from tensorflow import keras

def log_hist_weights(model, writer):
    def log_hist(epoch, logs):
        # One named histogram per trainable variable (kernels, biases, BN gamma/beta, ...)
        with writer.as_default():
            for tf_var in model.trainable_weights:
                tf.summary.histogram(tf_var.name, tf_var.numpy(), step=epoch)
    return log_hist

hist_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_hist_weights(baseline_model, file_writer))
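The callback above still logs biases and BatchNormalization parameters. If only the convolution filters are wanted, one option is to filter by layer type. The sketch below assumes the filters live in Conv2D layers; the helper log_conv_kernels, the small model and the random training data are hypothetical and exist only to exercise the callback. It relies on Conv2D.get_weights() returning the kernel first (followed by the bias when use_bias=True).

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras


def log_conv_kernels(model, writer):
    """Hypothetical helper: on_epoch_end hook that logs only Conv2D filter weights."""
    def log_hist(epoch, logs):
        with writer.as_default():
            for layer in model.layers:
                # Skip BatchNormalization, Dense, etc.; keep only convolution layers.
                if isinstance(layer, keras.layers.Conv2D):
                    # Index 0 of a Conv2D layer's get_weights() is the filter bank (kernel).
                    kernel = layer.get_weights()[0]
                    tf.summary.histogram(layer.name + "/kernel", kernel, step=epoch)
        writer.flush()
    return log_hist


# Hypothetical small CNN and random data, only to exercise the callback.
model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(8, 3, name="conv1"),
    keras.layers.BatchNormalization(name="bn1"),
    keras.layers.Conv2D(16, 3, name="conv2"),
    keras.layers.Flatten(),
    keras.layers.Dense(10),
])
model.compile(optimizer="adam",
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True))

log_dir = tempfile.mkdtemp(prefix="conv_hist_")
file_writer = tf.summary.create_file_writer(log_dir)
hist_callback = keras.callbacks.LambdaCallback(
    on_epoch_end=log_conv_kernels(model, file_writer))

x = np.random.rand(32, 28, 28, 1).astype("float32")
y = np.random.randint(0, 10, size=(32,))
model.fit(x, y, epochs=2, verbose=0, callbacks=[hist_callback])

# The directory should now contain a TensorBoard events file.
print(os.listdir(log_dir))
```

Pointing tensorboard --logdir at the same directory should then show one histogram per convolution layer, with tags derived from the layer names, and nothing for the BatchNormalization or Dense layers.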