adamconkey

Reputation: 4745

Does model.reset_states for LSTM affect any other non-LSTM layers in the model?

I am using the stateful mode of LSTMs in tf.keras, where I need to call reset_states manually once I have processed my sequence data, as described here. People normally seem to call model.reset_states(), but in my case the LSTM layer is embedded in a much more complex network that includes other layer types such as Dense and Conv. If I call model.reset_states() on the main model, which contains exactly one LSTM, should I worry about the reset affecting other layers such as the Dense or Conv layers? Would it be better to hunt down the LSTM layer and call reset_states on just that layer?

Upvotes: 3

Views: 443

Answers (2)

thushv89

Reputation: 11333

TLDR: Layers like LSTM/GRU have weights and states, whereas layers like Conv/Dense/Embedding have only weights. reset_states() only affects layers with states.

For an LSTM, reset_states() resets the c_t and h_t values stored in the layer (the cell state and the hidden state). These are the values you normally obtain by setting LSTM(n, return_state=True).

Embedding, Dense, and Conv layers don't hold such states, so model.reset_states() will not affect those kinds of feed-forward layers. It only affects recurrent layers like LSTMs and GRUs.

If you like, you can look at the source code and verify that this function checks whether each layer has a reset_states attribute (which feed-forward layers don't have).
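As a rough illustration of that check, here is a sketch in plain Python. FakeDense, FakeLSTM and FakeModel are hypothetical stand-ins, not Keras classes, so the snippet runs without TensorFlow. The loop condition (has reset_states and stateful is truthy) is an assumption about how the check works; the real implementation may differ between versions.

```python
class FakeDense:
    """Stand-in for a feed-forward layer: weights only, no state."""
    def __init__(self, name):
        self.name = name
        self.weights = [0.5, -1.2]


class FakeLSTM:
    """Stand-in for a stateful recurrent layer: weights plus h_t and c_t."""
    def __init__(self, name, stateful=True):
        self.name = name
        self.stateful = stateful
        self.weights = [0.1, 0.2]
        self.states = [3.0, 4.0]  # pretend h_t and c_t after a few batches

    def reset_states(self):
        # Zero the stored states; the weights are left alone.
        self.states = [0.0 for _ in self.states]


class FakeModel:
    def __init__(self, layers):
        self.layers = layers

    def reset_states(self):
        # Only layers that expose reset_states and are flagged stateful are touched.
        for layer in self.layers:
            if hasattr(layer, "reset_states") and getattr(layer, "stateful", False):
                layer.reset_states()


dense = FakeDense("dense")
lstm = FakeLSTM("lstm")
model = FakeModel([dense, lstm])
model.reset_states()

print(dense.weights)  # weights of the feed-forward stand-in
print(lstm.weights)   # weights of the recurrent stand-in
print(lstm.states)    # states of the recurrent stand-in
```

In this toy version the Dense stand-in is skipped entirely because it has no reset_states method, and the LSTM stand-in loses its states but keeps its weights.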

Upvotes: 2

OverLordGoldDragon

Reputation: 19816

Any layer with a settable stateful attribute is subject to reset_states(). The method iterates over each layer and checks whether it has stateful=True; if so, it calls that layer's reset_states() method (see source).

In Keras, all recurrent layers, including ConvLSTM2D, have a settable stateful attribute, and I'm not aware of any others that do. tensorflow.keras, however, has plenty of custom layer implementations that may; you can use the code below to check for sure:

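def print_statefuls(model):
    """Print the name of every layer that model.reset_states() would reset."""
    for layer in model.layers:
        # A layer qualifies only if it can be reset and has stateful=True.
        if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
            print(layer.name, "is stateful")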
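If you would rather reset only the recurrent layers yourself, the same check can drive the reset directly. This is a minimal sketch: reset_stateful_layers is a hypothetical helper name, and it relies only on model.layers, layer.name, stateful and reset_states(). The SimpleNamespace objects are stand-ins so the snippet runs without TensorFlow; with a real model you would pass the model itself.

```python
from types import SimpleNamespace


def reset_stateful_layers(model):
    """Reset only layers flagged stateful; return the names of the layers reset."""
    reset_names = []
    for layer in model.layers:
        if hasattr(layer, "reset_states") and getattr(layer, "stateful", False):
            layer.reset_states()
            reset_names.append(layer.name)
    return reset_names


# Hypothetical stand-ins for real layers, used only to exercise the helper.
calls = []
fake_lstm = SimpleNamespace(
    name="lstm",
    stateful=True,
    reset_states=lambda: calls.append("lstm"),  # records that a reset happened
)
fake_dense = SimpleNamespace(name="dense")  # no stateful flag, no reset_states
fake_model = SimpleNamespace(layers=[fake_dense, fake_lstm])

names = reset_stateful_layers(fake_model)
print(names)
```

Returning the names makes it easy to confirm that only the layers you expected were touched.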

Upvotes: 2
