Asutosh Panda

Reputation: 1473

What is the rule for deciding how many LSTM cells and how many units per LSTM cell you need in Keras?

I know that an LSTM cell contains a number of ANNs inside.

But when defining the hidden layers for the same problem, I have seen some people use only 1 LSTM cell while others use 2 or 3 LSTM cells, like this:

from keras.models import Sequential
from keras.layers import LSTM, Dropout, Dense, Activation

model = Sequential()
# First LSTM layer; input_shape is only needed here, Keras ignores it on later layers
model.add(LSTM(256, input_shape=(n_prev, 1), return_sequences=True))
model.add(Dropout(0.3))
model.add(LSTM(128, return_sequences=True))
model.add(Dropout(0.3))
# Last LSTM returns only the final timestep's output
model.add(LSTM(64, return_sequences=False))
model.add(Dropout(0.3))
model.add(Dense(1))
model.add(Activation('linear'))
  1. Is there any rule as to how many LSTM cells you should use? Or is it just manual experimentation?
  2. A follow-up question: how many units should you use in an LSTM cell? For the same problem, some people use 256 while others use 64.

Upvotes: 11

Views: 18188

Answers (1)

OverLordGoldDragon

Reputation: 19776

There are no "rules", but there are guidelines; in practice, you'd experiment with depth vs. width, each of which works differently:

  • RNN width is defined by (1) # of input channels; (2) # of the cell's filters (output channels/units). As with CNNs, each RNN filter is an independent feature extractor: more filters suit higher-complexity information, including but not limited to: dimensionality, modality, noise, frequency.
  • RNN depth is defined by (1) # of stacked layers; (2) # of timesteps. Specifics will vary by architecture, but from an information standpoint, unlike CNNs, RNNs are dense: every timestep influences the ultimate output of a layer, hence the ultimate output of the next layer - so it again isn't as simple as "more nonlinearity"; stacked RNNs exploit both spatial and temporal information (see the sketch after this list).
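To make the two axes concrete, here is a minimal Keras sketch contrasting a "wide" and a "deep" configuration (the unit counts and the n_prev window length are illustrative placeholders, not recommendations):

from keras.models import Sequential
from keras.layers import LSTM, Dense

n_prev = 50  # hypothetical lookback window length

# "Wide": a single layer with many units (filters)
wide = Sequential()
wide.add(LSTM(256, input_shape=(n_prev, 1)))
wide.add(Dense(1))

# "Deep": several stacked layers with fewer units each; intermediate
# layers must return full sequences so the next layer sees every timestep
deep = Sequential()
deep.add(LSTM(64, input_shape=(n_prev, 1), return_sequences=True))
deep.add(LSTM(64, return_sequences=True))
deep.add(LSTM(64))
deep.add(Dense(1))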

In general, width extracts more features, whereas depth extracts richer features - but if there aren't many features to extract from the given data, width should be lessened - and the "simpler" the data/problem, the fewer layers are suitable. Ultimately, however, it may be best to spare extensive analysis and try different combinations of each -- see this SO for more info.
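As a rough illustration of "try different combinations", a hypothetical depth-by-width sweep might look like the following (build_model, the candidate grids, and the training data are placeholders I've introduced, not part of the original answer):

from keras.models import Sequential
from keras.layers import LSTM, Dense

def build_model(n_layers, units, n_prev):
    # Stack n_layers LSTMs of the given width; all but the last must
    # return sequences so the next layer receives every timestep
    model = Sequential()
    for i in range(n_layers):
        kwargs = {'input_shape': (n_prev, 1)} if i == 0 else {}
        model.add(LSTM(units, return_sequences=(i < n_layers - 1), **kwargs))
    model.add(Dense(1))
    model.compile(loss='mse', optimizer='adam')
    return model

# Sweep depth x width and keep whichever validates best
for n_layers in (1, 2, 3):
    for units in (64, 128, 256):
        model = build_model(n_layers, units, n_prev=50)
        # model.fit(X_train, y_train, validation_data=(X_val, y_val), ...)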

Lastly, avoid standalone Dropout layers and use LSTM(recurrent_dropout=...) instead (see linked SO), as sketched below.
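Applied to the question's snippet, that means dropping the interleaved Dropout layers and passing the rate to the LSTM itself - a sketch (0.3 mirrors the question's rate and isn't a recommendation):

from keras.models import Sequential
from keras.layers import LSTM, Dense

model = Sequential()
# recurrent_dropout masks the recurrent (hidden-to-hidden) connections
# inside the cell instead of zeroing the layer's outputs wholesale
model.add(LSTM(256, input_shape=(n_prev, 1), return_sequences=True,
               recurrent_dropout=0.3))
model.add(LSTM(128, return_sequences=False, recurrent_dropout=0.3))
model.add(Dense(1, activation='linear'))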

Upvotes: 18
