Reputation: 3572
I am training an RNN on a large dataset that consists of disparate sources. I do not want the history of one set to spill over into the next, which means I want to reset the hidden state at the end of one set before sending in the next. How can I do that with Keras? The docs claim you can get into the low-level configuration.
What I am trying to do is reset the LSTM hidden state every time a new dataset is fed in, so that no influence from the previous dataset is carried forward. See the line
prevh = Hout[t-1] if t > 0 else h0
from Karpathy's simple Python implementation, https://gist.github.com/karpathy/587454dc0146a6ae21fc, line 45.
If I find the LSTM layer and call reset on it, I am worried that would wipe out the trained weights and biases, not just Hout.
Here is the training loop code:
for iteration in range(1, 10):
    for key in X_dict:
        X = X_dict[key]
        y = y_dict[key]
        history = model.fit(X, y, batch_size=batch_size, callbacks=cbks, nb_epoch=1, verbose=0)
Each turn of the inner loop feeds in data from a single market; that's where I'd like to reset Hout in the LSTM.
Upvotes: 3
Views: 1759
Reputation: 11543
To reset the states of your model, call .reset_states() on either a specific layer or on your entire model (source).
So if you have a list of datasets:
for ds in datasets:
    model.reset_states()
    model.fit(ds['inputs'], ds['targets'], ...)
Is that what you are looking for?
EDIT:
for iteration in range(1, 10):
    for key in X_dict:
        model.reset_states()  # reset the states of all the LSTMs in your network
        # model.layers[lstm_layer_index].reset_states()  # reset the states of this specific LSTM layer
        X = X_dict[key]
        y = y_dict[key]
        history = model.fit(X, y, batch_size=batch_size, callbacks=cbks, nb_epoch=1, verbose=0)
This is how you apply it.
By default, LSTMs are not stateful, which means they won't keep a hidden state after going over a sequence; the initial state when starting a new sequence is set to 0. If you set stateful=True, then the layer keeps the last hidden state (the output) of the previous sequence and uses it to initialize itself for the next sequence in the batch, as if the sequence were continuing.
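For context, here is a minimal sketch of what a stateful setup could look like (Keras 1.x-era Sequential API; the layer sizes and shapes are placeholders, not taken from your model):
from keras.models import Sequential
from keras.layers import LSTM, Dense

# stateful=True requires a fixed batch size, declared via
# batch_input_shape = (batch_size, timesteps, features) -- placeholder values below
model = Sequential()
model.add(LSTM(32, batch_input_shape=(32, 10, 4), stateful=True))
model.add(Dense(1))
model.compile(loss='mse', optimizer='rmsprop')
With that flag set, the final state after each batch seeds the initial state of the next batch, until you call reset_states().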
Calling model.reset_states() will just reset those last hidden states that were kept in memory to 0, as if the sequence were starting from scratch.
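As for the worry about wiping out the trained weights and biases: reset_states() does not touch them. If you want to convince yourself, here is a small check you could run (a sketch, assuming model is your compiled network):
# compare the weights before and after the reset
w_before = model.get_weights()
model.reset_states()
w_after = model.get_weights()
# the trained weights and biases are unchanged; only the hidden/cell states were zeroed
assert all((a == b).all() for a, b in zip(w_before, w_after))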
If you don't trust .reset_states() to do what you expect, feel free to go read the source code.
Upvotes: 2