Reputation: 3249
I am trying to understand tensorflow.keras.layers.SimpleRNN by building a simple digit classifier. The digits in the MNIST dataset are 28x28 images, so the main idea is to present one row of the image at each time step t. I have seen this idea in some blogs, for instance this one, which presents this image:
So my RNN is like this:
units = 128
self.model = Sequential()
self.model.add(layers.SimpleRNN(units, input_shape=(28, 28)))
self.model.add(Dense(self.output_size, activation='softmax'))
I know that an RNN is defined by the following equations:
Parameters:
W={w_{hh},w_{xh}} and V={v}.
input vector: x_t.
Update equations:
h_t=f(w_{hh} h_{t-1}+w_{xh} x_t).
y = v h_t.
Questions:
What exactly does "units=128" define? Is it the number of neurons in w_{hh} and w_{xh}? Is there anywhere I can find this information?
If I run self.model.summary()
I get
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
simple_rnn (SimpleRNN)       (None, 128)               20096
_________________________________________________________________
dense_35 (Dense)             (None, 10)                1290
=================================================================
Total params: 21,386
Trainable params: 21,386
Non-trainable params: 0
_________________________________________________________________
How do I get from the number of units to the parameter counts "20096" and "1290"?
Upvotes: 1
Views: 1021
Reputation: 641
units is the number of neurons in the layer, which is the dimensionality of the hidden state h_t and therefore of the layer's output. In your notation, w_{hh} is a 128 x 128 matrix and w_{xh} maps each 28-element input row to the 128 hidden units. This information can be found in the documentation.
The number of parameters depends on the size of the layer's input and on the number of units. For the SimpleRNN layer this is 128 * 128 + 128 * 28 + 128 = 20096 (see this answer): 128 * 128 recurrent weights, 128 * 28 input weights, and 128 biases. For the dense layer this is 128 * 10 + 10 = 1290. The added 128 and 10 are the bias weights of each layer, which are enabled by default.
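The arithmetic above can be checked with a few lines of plain Python. This is only a sketch: the helper names simple_rnn_param_count and dense_param_count are made up for illustration and are not part of Keras.

```python
def simple_rnn_param_count(units, input_dim, use_bias=True):
    """Count the trainable weights of a SimpleRNN-style layer (hypothetical helper)."""
    recurrent_kernel = units * units   # w_hh: hidden state -> hidden state
    input_kernel = input_dim * units   # w_xh: one input row -> hidden state
    bias = units if use_bias else 0    # one bias per unit
    return recurrent_kernel + input_kernel + bias


def dense_param_count(units, input_dim, use_bias=True):
    """Count the trainable weights of a fully connected layer (hypothetical helper)."""
    return input_dim * units + (units if use_bias else 0)


rnn_params = simple_rnn_param_count(units=128, input_dim=28)
dense_params = dense_param_count(units=10, input_dim=128)
print(rnn_params, dense_params, rnn_params + dense_params)  # 20096 1290 21386
```

Note that the number of time steps (28) does not appear anywhere: the same weights are reused at every step, so only the length of each row matters.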
input_shape = (28, 28) means that each input sample is a 28x28 array. input_shape does not include the batch dimension, so the layer receives tensors of shape (batch_size, 28, 28) and reads each sample as a sequence of 28 time steps, each a vector of length 28 (one image row per step, as depicted in your image). Sequences of a different length are usually split or truncated to fit the given input_shape. If a sequence is shorter than the input_shape, padding can be applied to make it fit.
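To make the shapes concrete, here is a minimal NumPy sketch of the recurrence from the question, with a bias term added. It assumes tanh as the activation f (the Keras default for SimpleRNN); the random weights and the name simple_rnn_forward are placeholders for illustration, not the Keras implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
units, timesteps, features = 128, 28, 28

# Placeholder weights with the same shapes as the layer's parameters.
w_xh = rng.normal(scale=0.1, size=(features, units))  # input row -> hidden
w_hh = rng.normal(scale=0.1, size=(units, units))     # hidden -> hidden
b = np.zeros(units)                                    # one bias per unit


def simple_rnn_forward(x):
    """Apply h_t = tanh(x_t w_xh + h_{t-1} w_hh + b) over all time steps.

    x has shape (batch, timesteps, features); the last hidden state is returned.
    """
    h = np.zeros((x.shape[0], units))
    for t in range(x.shape[1]):
        # x[:, t, :] is row t of every image in the batch.
        h = np.tanh(x[:, t, :] @ w_xh + h @ w_hh + b)
    return h


images = rng.random((4, timesteps, features))  # a fake batch of four 28x28 images
h_last = simple_rnn_forward(images)
print(h_last.shape)  # (4, 128)
```

The output has one 128-element vector per image, which matches the (None, 128) output shape in the summary, where None stands for the batch size.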
Upvotes: 2