MoneyBall

Reputation: 2563

How to get summary information on tensorflow RNN

I implemented a simple RNN using tensorflow, shown below:

cell = tf.contrib.rnn.BasicRNNCell(state_size)
cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)

rnn_outputs, final_state = tf.nn.dynamic_rnn(cell, batch_size, dtype=tf.float32)

This works fine. But I'd like to log the weight variables to summary writer. Is there any way to do this?

By the way, do we use tf.nn.rnn_cell.BasicRNNCell or tf.contrib.rnn.BasicRNNCell? Or are they identical?

Upvotes: 5

Views: 770

Answers (2)

amin__

Reputation: 1058

Maxim's answer is great. I found another approach useful that does not require you to know the names of the weight variables. It uses an optimizer object and its compute_gradients method.

Say you compute the "loss" from the (outputs, states) returned by the dynamic_rnn call. Then create an optimizer of your choice, for example Adam:

optzr = tf.train.AdamOptimizer(learning_rate)
grads_and_vars = optzr.compute_gradients(loss)

"grads_and_vars" is a list of (gradient, variable) pairs. By iterating over it you can reach every trainable weight and bias together with its gradient; the gradient is None for a variable that does not affect the loss. For example:

for grad, var in grads_and_vars:
    print(var, var.name)
    # var.op.name omits the ":0" suffix, which is not legal in a summary name
    tf.summary.histogram(var.op.name, var)

Ref: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer#compute_gradients
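
The loop above can be sketched end to end as follows. This is only an illustration under assumptions: it uses the graph-mode API through tf.compat.v1, and the variables rnn/kernel, rnn/bias and rnn/unused are hypothetical stand-ins for the weights that dynamic_rnn would create. It also logs each gradient and skips variables whose gradient is None.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API (assumption: TF 1.x style code)

graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 3], name="x")

    # Hypothetical stand-ins for the RNN weights; the names are illustrative.
    with tf1.variable_scope("rnn"):
        kernel = tf1.get_variable("kernel", shape=[3, 5])
        bias = tf1.get_variable("bias", shape=[5],
                                initializer=tf1.zeros_initializer())
        unused = tf1.get_variable("unused", shape=[2])  # not used by the loss

    loss = tf.reduce_mean(tf.tanh(tf.matmul(x, kernel) + bias))

    optimizer = tf1.train.AdamOptimizer(learning_rate=0.01)
    grads_and_vars = optimizer.compute_gradients(loss)

    summary_names = []   # variables that received a histogram
    gradient_names = []  # variables whose gradient was also logged
    for grad, var in grads_and_vars:
        name = var.op.name  # e.g. "rnn/kernel", without the ":0" suffix
        tf1.summary.histogram(name, var)
        summary_names.append(name)
        if grad is not None:  # None when the variable does not affect the loss
            tf1.summary.histogram(name + "/gradient", grad)
            gradient_names.append(name)

    # Single op that evaluates every histogram registered above.
    merged = tf1.summary.merge_all()
```

Evaluating merged in a session and passing the result to a tf.summary.FileWriter is what actually writes the histograms to disk.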

Upvotes: 2

Maxim

Reputation: 53758

But I'd like to log the weight variables to summary writer. Is there any way to do this?

You can retrieve the variable with tf.get_variable() inside a variable scope opened with reuse=True. Since tf.summary.histogram takes a tensor, it can be easier to look the tensor up by name with Graph.get_tensor_by_name():

import tensorflow as tf

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)

with tf.variable_scope('rnn', reuse=True):
  print(tf.get_variable('basic_rnn_cell/kernel'))

kernel = tf.get_default_graph().get_tensor_by_name('rnn/basic_rnn_cell/kernel:0')
tf.summary.histogram('kernel', kernel)
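
Creating a histogram op only adds it to the graph; to see it in TensorBoard the op still has to be evaluated and handed to a writer. A minimal sketch of that step, assuming the graph-mode API through tf.compat.v1: the variable rnn/kernel is a hypothetical stand-in for the cell's weights, the log directory is a temporary folder, and the variables are found through tf.trainable_variables() instead of being looked up by name.

```python
import os
import tempfile
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API (assumption: TF 1.x style code)

logdir = tempfile.mkdtemp()  # hypothetical log directory for TensorBoard

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-in for the kernel that dynamic_rnn would create.
    with tf1.variable_scope("rnn"):
        tf1.get_variable("kernel", shape=[8, 5])

    # Register one histogram per trainable variable, named after the variable.
    for var in tf1.trainable_variables():
        tf1.summary.histogram(var.op.name, var)

    merged = tf1.summary.merge_all()
    init = tf1.global_variables_initializer()

    with tf1.Session() as sess:
        sess.run(init)
        writer = tf1.summary.FileWriter(logdir, sess.graph)
        for step in range(3):
            # In real training this would run together with the train op.
            summary = sess.run(merged)
            writer.add_summary(summary, global_step=step)
        writer.close()

event_files = os.listdir(logdir)
```

Pointing TensorBoard at the log directory (tensorboard --logdir <logdir>) should then show the histograms, one entry per logged step.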

By the way, do we use tf.nn.rnn_cell.BasicRNNCell or tf.contrib.rnn.BasicRNNCell? Or are they identical?

Yes, they are synonyms, but I prefer the tf.nn.rnn_cell package, because everything in tf.contrib is considered experimental and may change between 1.x releases.

Upvotes: 4
