bhomass

Reputation: 3572

Can I reset the hidden state of an RNN between input data sets in Keras?

I am training an RNN on a large data set that consists of disparate sources. I do not want the history of one set to spill over to the next, which means I want to reset the hidden state at the end of one set before feeding in the next. How can I do that with Keras? The docs claim you can get into the low-level configuration.

What I am trying to do is reset the LSTM hidden state every time a new data set is fed in, so that no influence from the previous dataset is carried forward. See this line:

prevh = Hout[t-1] if t > 0 else h0

from line 45 of Karpathy's simple Python implementation: https://gist.github.com/karpathy/587454dc0146a6ae21fc
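To make the carry-over concrete, here is a minimal NumPy sketch of that pattern. The names `Hout`, `h0`, and `prevh` follow Karpathy's gist, but the cell is simplified to a plain tanh RNN step rather than a full LSTM; the weights and dimensions are made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 5, 4                          # timesteps, hidden size (arbitrary)
Wxh = rng.standard_normal((d, d)) * 0.1
Whh = rng.standard_normal((d, d)) * 0.1
X = rng.standard_normal((T, d))

def forward(X, h0):
    """Run one sequence, carrying the hidden state across timesteps."""
    Hout = np.zeros((X.shape[0], h0.shape[0]))
    for t in range(X.shape[0]):
        prevh = Hout[t - 1] if t > 0 else h0   # the line in question
        Hout[t] = np.tanh(X[t] @ Wxh + prevh @ Whh)
    return Hout

# "Resetting" between datasets just means starting from h0 = 0 again,
# so the second run is not influenced by the first run's final state.
h_a = forward(X, np.zeros(d))
h_b = forward(X, np.zeros(d))   # fresh start: identical to h_a
h_c = forward(X, h_a[-1])       # carried-over state: differs from h_a
```

The only thing a reset touches here is the initial `h0`; the weights `Wxh` and `Whh` are unaffected.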

If I find the LSTM layer and call reset on it, I am worried that it would wipe out the entire training of the weights and biases, not just Hout.

Here is the training loop code

for iteration in range(1, 10):
    for key in X_dict:
        X = X_dict[key]
        y = y_dict[key]
        history = model.fit(X, y, batch_size=batch_size, callbacks=cbks, nb_epoch=1, verbose=0)

Each turn of the loop feeds in data from a single market. That's where I'd like to reset the Hout in the LSTM.

Upvotes: 3

Views: 1759

Answers (1)

Nassim Ben

Reputation: 11543

To reset the states of your model, call .reset_states() on either a specific layer, or on your entire model. source

So if you have a list of datasets:

for ds in datasets:
    model.reset_states()
    model.fit(ds['inputs'], ds['targets'], ...)

Is that what you are looking for?

EDIT :

for iteration in range(1, 10):
    for key in X_dict:
        model.reset_states()  # reset the states of all the LSTMs in your network
        # model.layers[lstm_layer_index].reset_states()  # or reset only this specific LSTM layer
        X = X_dict[key]
        y = y_dict[key]
        history = model.fit(X, y, batch_size=batch_size, callbacks=cbks, nb_epoch=1, verbose=0)

This is how you apply it.

By default, LSTMs are not stateful, which means they won't keep a hidden state after going over a sequence: the initial state when starting a new sequence is set to 0. If you set stateful=True, the layer keeps the last hidden state (the output) of the previous sequence to initialize itself for the next sequence in the batch, as if the sequence were continuing.

Calling model.reset_states() just resets those stored hidden states to 0, as if the sequence were starting from scratch. It does not touch the trained weights.
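To address the worry about wiping out training: here is a toy illustration (not the Keras source, just a hypothetical stand-in) of why resetting states is safe. A stateful layer keeps its trainable weights and its running hidden state separately, and reset only zeroes the state:

```python
import numpy as np

class ToyStatefulRNN:
    """Made-up minimal stateful RNN: weights and state live side by side."""
    def __init__(self, d, seed=0):
        rng = np.random.default_rng(seed)
        self.Wxh = rng.standard_normal((d, d)) * 0.1  # trained weights
        self.Whh = rng.standard_normal((d, d)) * 0.1
        self.h = np.zeros(d)                          # running hidden state

    def step(self, x):
        self.h = np.tanh(x @ self.Wxh + self.h @ self.Whh)
        return self.h

    def reset_states(self):
        self.h = np.zeros_like(self.h)                # weights untouched

rnn = ToyStatefulRNN(4)
W_before = rnn.Wxh.copy()
rnn.step(np.ones(4))      # state becomes nonzero
rnn.reset_states()        # state back to 0, Wxh/Whh unchanged
```

After the reset, `rnn.h` is all zeros while `rnn.Wxh` is bitwise identical to `W_before`.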

If you don't trust .reset_states() to do what you expect, feel free to look at the source code.

Upvotes: 2
