Raghuram Vadapalli

Reputation: 1240

Accessing RNN weights - TensorFlow

I am using tf.python.ops.rnn_cell.GRUCell with tf.nn.dynamic_rnn:

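import tensorflow as tf
from tensorflow.python.ops.rnn_cell import GRUCell

# HID_DIM, sequence and length() are defined elsewhere in my code.
output, state = tf.nn.dynamic_rnn(
    GRUCell(HID_DIM),
    sequence,
    dtype=tf.float32,
    sequence_length=length(sequence)
)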

How do I get the weights of this GRUCell? I need to inspect them for debugging.

Upvotes: 6

Views: 1090

Answers (1)

Filipe Aleixo

Reputation: 4244

The values of all trainable variables, including the GRUCell's weights and biases, can be printed by fetching them with sess.run:

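import tensorflow as tf

with tf.Session() as sess:
    # A new session holds no variable values: initialize them (or restore a
    # checkpoint); to see trained weights, reuse the training session instead.
    sess.run(tf.global_variables_initializer())
    variables_names = [v.name for v in tf.trainable_variables()]
    values = sess.run(variables_names)
    for k, v in zip(variables_names, values):
        print(k, v)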
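Once the arrays are fetched, one way to sanity-check them while debugging is to recompute a GRU step by hand and compare it with the cell's output. The NumPy-only sketch below assumes the TF 1.x GRUCell layout: variables named like rnn/gru_cell/gates/kernel, rnn/gru_cell/gates/bias, rnn/gru_cell/candidate/kernel and rnn/gru_cell/candidate/bias, with the gate kernel's columns ordered reset gate first, then update gate. Older or newer TensorFlow releases may use different names or layouts, so check the names and shapes your loop actually prints. The function name gru_step and the sizes used are hypothetical, for illustration only.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(x, h, gates_kernel, gates_bias, cand_kernel, cand_bias):
    """One GRU step, assuming TF 1.x GRUCell's weight layout (hypothetical helper).

    x:            [batch, input_dim]
    h:            [batch, num_units]
    gates_kernel: [input_dim + num_units, 2 * num_units]  (columns: r then u)
    gates_bias:   [2 * num_units]
    cand_kernel:  [input_dim + num_units, num_units]
    cand_bias:    [num_units]
    """
    xh = np.concatenate([x, h], axis=1)
    gates = sigmoid(xh @ gates_kernel + gates_bias)
    r, u = np.split(gates, 2, axis=1)            # reset gate, update gate
    xrh = np.concatenate([x, r * h], axis=1)
    c = np.tanh(xrh @ cand_kernel + cand_bias)   # candidate state
    return u * h + (1.0 - u) * c                 # new hidden state


# Hypothetical sizes and all-zero weights, for illustration only.
input_dim, num_units, batch = 3, 4, 2
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, input_dim))
h = rng.standard_normal((batch, num_units))
gk = np.zeros((input_dim + num_units, 2 * num_units))
gb = np.zeros(2 * num_units)
ck = np.zeros((input_dim + num_units, num_units))
cb = np.zeros(num_units)

# With zero weights both gates are 0.5 and the candidate is 0, so h_new == 0.5 * h.
h_new = gru_step(x, h, gk, gb, ck, cb)
```

To check real weights, pass the arrays fetched above (matched by name) plus an input and previous state, and compare the result with the corresponding step of the dynamic_rnn output.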

Upvotes: 2
