Wuchen

Reputation: 561

How can I visualize the weights (variables) of a CNN in TensorFlow?

After training the CNN model, I want to visualize or print out the weights. What can I do? I cannot even print out the variables after training. Thank you!

Upvotes: 46

Views: 57108

Answers (4)

YScharf

Reputation: 2082

Using the TensorFlow 2 API, there are several options:

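Weights, extracted using the get_weights() function (for a Conv2D or Dense layer, index [0] is the kernel and index [1], if the layer uses a bias, is the bias):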

weights_n = model.layers[n].get_weights()[0]

Bias, extracted by converting the layer's bias variable to a NumPy array with numpy():

bias_n = model.layers[n].bias.numpy()
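Both calls return plain NumPy arrays, so printing or summarizing the weights needs nothing TensorFlow-specific afterwards. A minimal sketch, assuming the kernel of a Keras Conv2D layer, whose layout is (kernel_height, kernel_width, in_channels, filters); the helper name `describe_kernel` is hypothetical, and a random array stands in for `weights_n` so the snippet runs without a trained model:

```python
import numpy as np


def describe_kernel(kernel):
    """Return (min, max, mean) for every filter of a conv kernel.

    Assumes the Keras Conv2D layout (kernel_height, kernel_width,
    in_channels, filters), i.e. one filter per index of the last axis.
    """
    stats = []
    for i in range(kernel.shape[-1]):
        f = kernel[..., i]  # one filter: (kernel_height, kernel_width, in_channels)
        stats.append((float(f.min()), float(f.max()), float(f.mean())))
    return stats


if __name__ == "__main__":
    # Random stand-in for weights_n = model.layers[n].get_weights()[0]
    rng = np.random.default_rng(0)
    weights_n = rng.normal(size=(3, 3, 1, 4))

    np.set_printoptions(precision=3, suppress=True)
    print(weights_n[..., 0].squeeze())  # the first 3x3 filter
    for i, (lo, hi, mean) in enumerate(describe_kernel(weights_n)):
        print(f"filter {i}: min={lo:.3f} max={hi:.3f} mean={mean:.3f}")
```

With a real model, replace the random array by the array returned from get_weights()[0].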

Upvotes: 0

Martin Thoma

Reputation: 136865

You can extract the values as numpy arrays the following way:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

# Run this while the trained session is the default one, e.g. inside
# `with sess.as_default():`, because .eval() needs a default session.
with tf.variable_scope('conv1', reuse=True):
    W_conv1 = tf.get_variable('weights', shape=[5, 5, 1, 32])
    weights = W_conv1.eval()
# np.save writes the binary .npy format, so pass a filename, not a text-mode file.
np.save("conv1.weights.npy", weights)

Note that you have to adjust the scope ('conv1' in my case) and the variable name ('weights' in my case).

Then it boils down to visualizing NumPy arrays. One way to visualize them is:

#!/usr/bin/env python

"""Visualize numpy arrays."""

import numpy as np
import matplotlib.pyplot as plt  # scipy.misc.imshow was removed in SciPy 1.2

arr = np.load('conv1.weights.npy')

# Get each 5x5 filter from the 5x5x1x32 array
for filter_ in range(arr.shape[3]):
    # Get the 5x5x1 filter:
    extracted_filter = arr[:, :, :, filter_]

    # Get rid of the last dimension (hence get 5x5):
    extracted_filter = np.squeeze(extracted_filter)

    # Display the filter; interpolation='nearest' keeps the 5x5 pixels sharp
    # when enlarged. plt.show() blocks until the window is closed.
    plt.imshow(extracted_filter, cmap='gray', interpolation='nearest')
    plt.show()

Upvotes: 7

etoropov

Reputation: 1225

As @mrry said, you can use tf.image_summary. For example, in cifar10_train.py you can put this code somewhere inside def train(). Note how you access a variable under the scope 'conv1':

# Visualize conv1 features
# reuse=True is needed to fetch the existing 'conv1/weights' variable
# instead of trying to create a new one.
with tf.variable_scope('conv1', reuse=True):
  weights = tf.get_variable('weights')

  # Scale weights to [0, 1], then convert to uint8 in [0, 255] (maybe change scaling?)
  x_min = tf.reduce_min(weights)
  x_max = tf.reduce_max(weights)
  weights_0_to_1 = (weights - x_min) / (x_max - x_min)
  weights_0_to_255_uint8 = tf.image.convert_image_dtype(weights_0_to_1, dtype=tf.uint8)

  # to tf.image_summary format [batch_size, height, width, channels]
  weights_transposed = tf.transpose(weights_0_to_255_uint8, [3, 0, 1, 2])

  # This displays the first 3 of the 64 filters in conv1
  # (TensorFlow 1.x and later: tf.summary.image(..., max_outputs=3)).
  tf.image_summary('conv1/filters', weights_transposed, max_images=3)
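The min-max scaling and the transpose in the code above are plain array operations, so they can be checked outside TensorFlow. A minimal NumPy sketch of the same steps; the function name `to_image_batch` is hypothetical, and rounding to uint8 approximates rather than exactly reproduces tf.image.convert_image_dtype:

```python
import numpy as np


def to_image_batch(weights):
    """Turn a kernel [height, width, in_channels, out_channels] into a
    uint8 image batch [out_channels, height, width, in_channels]."""
    w_min, w_max = weights.min(), weights.max()
    # One global min/max for all filters, as in the TensorFlow code
    # (assumes the weights are not all equal, else this divides by zero).
    weights_0_to_1 = (weights - w_min) / (w_max - w_min)
    weights_uint8 = np.round(weights_0_to_1 * 255).astype(np.uint8)
    # Same axis order as tf.transpose(..., [3, 0, 1, 2]).
    return np.transpose(weights_uint8, (3, 0, 1, 2))


if __name__ == "__main__":
    # Random stand-in for a conv1 kernel with 3 input channels and 64 filters.
    kernel = np.random.default_rng(0).normal(size=(5, 5, 3, 64))
    batch = to_image_batch(kernel)
    print(batch.shape, batch.dtype, batch.min(), batch.max())
```

Using one global range keeps the relative magnitudes of the filters comparable; scaling each filter separately would show more detail per filter but hide which filters have larger weights.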

If you want to visualize all your conv1 filters in one nice grid, you have to arrange them into a grid yourself. I did that today, so I'd like to share a gist for visualizing conv1 as a grid.
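A minimal NumPy sketch of the grid idea (not the code from the gist), assuming a kernel laid out as [height, width, in_channels, out_channels]; the helper `filters_to_grid` is hypothetical. It scales all filters to [0, 1] with one shared range and tiles them into a roughly square grid with a one-pixel border:

```python
import math

import numpy as np


def filters_to_grid(weights, pad=1):
    """Tile conv filters into a single image.

    weights: array of shape [height, width, in_channels, out_channels].
    Returns an array [grid_height, grid_width, in_channels] in [0, 1],
    with `pad` pixels of zero padding around each filter.
    """
    h, w, c, n = weights.shape
    w_min, w_max = weights.min(), weights.max()
    scaled = (weights - w_min) / (w_max - w_min + 1e-12)  # avoid division by zero

    cols = math.ceil(math.sqrt(n))  # roughly square grid
    rows = math.ceil(n / cols)
    grid = np.zeros((rows * (h + pad) + pad, cols * (w + pad) + pad, c))
    for i in range(n):
        r, col = divmod(i, cols)
        top = pad + r * (h + pad)
        left = pad + col * (w + pad)
        grid[top:top + h, left:left + w, :] = scaled[:, :, :, i]
    return grid


if __name__ == "__main__":
    kernel = np.random.default_rng(0).normal(size=(5, 5, 3, 64))
    grid = filters_to_grid(kernel)
    print(grid.shape)  # (49, 49, 3): 64 filters in an 8x8 grid
```

With 3 input channels the result can be shown directly with matplotlib's plt.imshow(grid); with a single input channel, pass grid[:, :, 0] together with cmap='gray'.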

Upvotes: 22

mrry

Reputation: 126194

To visualize the weights, you can use a tf.image_summary() op to transform a convolutional filter (or a slice of a filter) into a summary proto, write the summaries to a log using a tf.train.SummaryWriter, and visualize the log using TensorBoard.

Let's say you have the following (simplified) program:

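import tensorflow as tf  # pre-1.0 TensorFlow API

# conv2d expects a 4-D filter [height, width, in_channels, out_channels]
# and a 4-D NHWC input, so both get an explicit depth of 3 (RGB).
conv_filter = tf.Variable(tf.truncated_normal([8, 8, 3, 16]))
images = tf.placeholder(tf.float32, shape=[None, 28, 28, 3])

conv = tf.nn.conv2d(images, conv_filter, strides=[1, 1, 1, 1], padding="SAME")

# More ops...
loss = ...
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss)

# image_summary needs a tag and a [batch, height, width, channels] tensor, so
# move out_channels to the front: each of the 16 filters becomes one RGB image.
filter_summary = tf.image_summary('conv_filter',
                                  tf.transpose(conv_filter, [3, 0, 1, 2]),
                                  max_images=16)

sess = tf.Session()
sess.run(tf.initialize_all_variables())
summary_writer = tf.train.SummaryWriter('/tmp/logs', sess.graph_def)
for i in range(10000):
  sess.run(train_op, feed_dict={images: ...})  # ... = a batch of input images
  if i % 10 == 0:
    # Evaluate the summary op and log the result every 10 steps.
    summary_str = sess.run(filter_summary)
    summary_writer.add_summary(summary_str, i)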

After doing this, you can start TensorBoard to visualize the logs in /tmp/logs (for example with tensorboard --logdir=/tmp/logs), and you will be able to see a visualization of the filter.

Note that this trick visualizes depth-3 filters as RGB images (to match the channels of the input image). If you have deeper filters, or they don't make sense to interpret as color channels, you can use the tf.split() op to split the filter on the depth dimension, and generate one image summary per depth.
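What the tf.split() step does to the shapes can be sketched in NumPy: splitting a [height, width, in_channels, out_channels] filter along the depth (in_channels) axis gives one single-channel image batch per depth, each in the [batch, height, width, channels] layout that image summaries expect. The helper name `split_by_depth` is hypothetical, and np.split stands in for tf.split here:

```python
import numpy as np


def split_by_depth(weights):
    """Split a kernel [height, width, in_channels, out_channels] along depth.

    Returns a list with one array per input channel, each shaped
    [out_channels, height, width, 1], i.e. a batch of grayscale images.
    """
    depth = weights.shape[2]
    slices = np.split(weights, depth, axis=2)  # each: [h, w, 1, out_channels]
    return [np.transpose(s, (3, 0, 1, 2)) for s in slices]


if __name__ == "__main__":
    kernel = np.random.default_rng(0).normal(size=(8, 8, 16, 32))
    per_depth = split_by_depth(kernel)
    print(len(per_depth), per_depth[0].shape)  # 16 (32, 8, 8, 1)
```

Each array in the list would then get its own image summary, with a separate tag per depth.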

Upvotes: 37
