I've come across research publications and Q&As discussing a need for inspecting RNN weights; some related answers are in the right direction, suggesting get_weights() - but how do I actually visualize the weights meaningfully? Namely, LSTMs and GRUs have gates, and all RNNs have channels that serve as independent feature extractors - so how do I (1) fetch per-gate weights, and (2) plot them in an informative manner?
Keras/TF builds RNN weights in a well-defined order, which can be inspected from the source code or via layer.__dict__ directly; these can then be used to fetch per-kernel and per-gate weights, with per-channel treatment following from the tensor's shape. The code and explanations below cover every possible case of a Keras/TF RNN, and should be easily extendable to any future API changes.
Also see visualizing RNN gradients, and an application to RNN regularization. Unlike in the former post, I won't be including a simplified variant here, as it'd still be rather large and complex per the nature of weight extraction and organization; instead, view the relevant source code in the repository (see next section).
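As a minimal sketch of the per-gate fetching described above (assuming tf.keras; the i, f, c, o gate ordering is per the Keras LSTM source code, and the toy model's shapes match EX 1 below):

    from tensorflow.keras.layers import Input, LSTM
    from tensorflow.keras.models import Model

    # toy model; batch_shape = (16, 100, 20) as in EX 1 below
    ipt = Input(batch_shape=(16, 100, 20))
    out = LSTM(256, return_sequences=True)(ipt)
    model = Model(ipt, out)

    lstm = model.layers[1]
    kernel, recurrent_kernel, bias = lstm.get_weights()
    # kernel:           (input_dim, 4*units) -- input-to-hidden
    # recurrent_kernel: (units,     4*units) -- hidden-to-hidden
    # bias:             (4*units,)

    units = lstm.units
    gates = {name: kernel[:, i * units:(i + 1) * units]
             for i, name in enumerate(('i', 'f', 'c', 'o'))}
    for name, w in gates.items():
        print(name, w.shape)  # (20, 256) each; columns = per-channel weights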
Code source: See RNN (this post included w/ bigger images), my repository; included are RNNs from keras & from tf.keras.
Visualization methods: rnn_histogram & rnn_heatmap (used throughout the examples below)
EX 1: uni-LSTM, 256 units, weights -- batch_shape = (16, 100, 20) (input)
    rnn_histogram(model, 'lstm', equate_axes=False, show_bias=False)
    rnn_histogram(model, 'lstm', equate_axes=True, show_bias=False)
    rnn_heatmap(model, 'lstm')
equate_axes=True for an even comparison across kernels and gates, improving quality of comparison, but potentially degrading visual appeal
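For intuition on what rnn_histogram shows, here's a bare-bones sketch continuing from the slicing snippet above (sharey plays the role of equate_axes; the full implementation handles kernels, biases, directions, and much more):

    import matplotlib.pyplot as plt

    # one weight histogram per gate; `gates` is from the earlier snippet
    # sharey=True equates the y-axes across gates, as equate_axes does
    fig, axes = plt.subplots(1, 4, figsize=(12, 3), sharey=True)
    for ax, (name, w) in zip(axes, gates.items()):
        ax.hist(w.ravel(), bins=150)
        ax.set_title("gate '%s'" % name)
    plt.show()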
EX 2: bi-CuDNNLSTM, 256 units, weights -- batch_shape = (16, 100, 16) (input)
    rnn_histogram(model, 'bidir', equate_axes=2)
    rnn_heatmap(model, 'bidir', norm=(-.8, .8))
CuDNNLSTM (and CuDNNGRU) biases are defined and initialized differently - something that can't be inferred from histograms
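The bias difference is easy to verify by shape inspection alone; a sketch, assuming a setup where CuDNNLSTM is importable (keras.layers in TF1-era Keras, tf.compat.v1.keras.layers in TF2) and builds on your device:

    from tensorflow.compat.v1 import keras

    ipt = keras.layers.Input(batch_shape=(16, 100, 16))
    out = keras.layers.CuDNNLSTM(256, return_sequences=True)(ipt)
    model = keras.Model(ipt, out)

    for w in model.layers[1].get_weights():
        print(w.shape)
    # expected: (16, 1024), (256, 1024), (2048,)
    # vs. plain LSTM's (1024,) bias: cuDNN keeps separate input and
    # recurrent bias vectors, hence 8*units rather than 4*units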
EX 3: uni-CuDNNGRU, 64 units, weight gradients -- batch_shape = (16, 100, 16) (input)
    rnn_heatmap(model, 'gru', mode='grads', input_data=x, labels=y, cmap=None, absolute_value=True)
Plotted with absolute_value=True and a greyscale colormap
New is the most active kernel gate (input-to-hidden), suggesting more error correction on permitting information flow
Reset is the least active recurrent gate (hidden-to-hidden), suggesting least error correction on memory-keeping
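For reference, mode='grads' concerns d(loss)/d(weights); a minimal TF2-eager sketch of fetching such weight gradients (the actual implementation differs, and also supports TF1 graph mode; loss and data here are stand-ins):

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import Input, GRU
    from tensorflow.keras.models import Model

    # toy uni-GRU matching EX 3's shapes
    ipt = Input(batch_shape=(16, 100, 16))
    out = GRU(64, return_sequences=True)(ipt)
    model = Model(ipt, out)

    x = np.random.randn(16, 100, 16).astype('float32')
    y = np.random.randn(16, 100, 64).astype('float32')

    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))  # MSE stand-in
    grads = tape.gradient(loss, model.layers[1].trainable_weights)
    kernel_grads = grads[0].numpy()  # (input_dim, 3*units)

    units = 64
    z, r, h = [kernel_grads[:, i * units:(i + 1) * units] for i in range(3)]
    # Keras GRU gate order: z (update), r (reset), h (new / candidate)
    for name, g in zip('zrh', (z, r, h)):
        print(name, np.abs(g).mean())  # per-gate gradient magnitudes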
BONUS EX: LSTM NaN detection, 512 units, weights -- batch_shape = (16, 100, 16) (input)
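At its core, the NaN detection amounts to scanning fetched weights for NaNs and reporting where they occur; a minimal version (the repository's also localizes NaNs per gate and channel within the plots):

    import numpy as np

    def detect_nans(layer):
        # count NaNs per weight tensor; names assume a standard RNN layer
        for w, name in zip(layer.get_weights(),
                           ('kernel', 'recurrent_kernel', 'bias')):
            n_nan = np.isnan(w).sum()
            if n_nan:
                print("found %d NaNs in %s, shape %s" % (n_nan, name, w.shape))

    detect_nans(model.layers[1])  # e.g. the 512-unit LSTM of this example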