user2473992

Reputation: 117

On loading the saved Keras sequential model, my test data gives low accuracy in the beginning

I am building a simple Keras Sequential model that takes 10k inputs in batches of 100. Each input has 3 columns, and the corresponding output is the sum of that row.

The Sequential model has 2 layers: LSTM (stateful=True) and Dense.

After compiling and fitting the model, I save it to the file 'model.h5'.

Then I load the saved model and call model.predict on test data (size = 10k, batch_size = 100).

Problem: the predictions are inaccurate for the first 400-500 inputs; for the rest they are accurate, with very low val_loss.

Case 1: I make the LSTM layer stateless (i.e. stateful=False). In this case Keras gives very accurate outputs for all of the test data.

Case 2: Instead of saving and reloading, I call model.predict directly on the model I just created. In this case all the outputs are accurate.

But I need stateful=True, and I also want to save my model and resume work on it later.

1. Is there any way to solve this?

2. Also, when I feed in the test data, why does the model's accuracy improve? (The first 400-500 test inputs give inaccurate results and the rest are pretty accurate.)

Upvotes: 1

Views: 788

Answers (1)

Marcin Możejko

Reputation: 40516

Your problem seems to come from losing the hidden states of your recurrent cells. When the model is rebuilt on loading, the states may be reset, and this could explain the inaccurate predictions at the start.
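To illustrate why only the first predictions would suffer, here is a toy NumPy sketch (not Keras code). It assumes a simple contractive recurrence, h = 0.5 * h + x, as a stand-in for an LSTM; the function name `run` and all values are hypothetical. Two runs over the same inputs, one starting from a carried-over state and one from a reset (zero) state, drift back together after a few steps, which would be consistent with wrong early outputs and accurate later ones.

```python
import numpy as np

def run(inputs, h0, decay=0.5):
    """Toy stateful recurrence h_t = decay * h_{t-1} + x_t (a stand-in, not an LSTM)."""
    h = float(h0)
    outputs = []
    for x in inputs:
        h = decay * h + x
        outputs.append(h)
    return np.array(outputs)

rng = np.random.default_rng(0)
inputs = rng.normal(size=20)

kept = run(inputs, h0=3.0)   # state carried over from training
reset = run(inputs, h0=0.0)  # state lost when the model was reloaded

# The gap between the two runs shrinks by the decay factor at every step.
gap = np.abs(kept - reset)
print(gap[:5])
```

A real LSTM is not this simple, but the influence of its initial state typically fades in a similar way, so a reloaded model with reset states can "recover" after a few batches.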

So (it's a little bit cumbersome) you could also save and load the states of your network:

  1. How to save? (assuming that the i-th layer is a recurrent one):

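    import numpy
    
    # An LSTM layer keeps two states, in this order: [hidden state, cell state].
    hidden_state = model.layers[i].states[0].eval()
    cell_state = model.layers[i].states[1].eval()
    
    # numpy.save appends ".npy" to each file name.
    numpy.save("some name", hidden_state)
    numpy.save("some other name", cell_state)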
    
  2. Now you can reload the saved states (e.g. with numpy.load); here you can read on how to set the hidden state in a layer.

Of course, it's best to wrap all of this in some kind of object, e.g. in the constructor and methods of a class.
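The save and restore steps above could be packed into two small helpers. This is a minimal sketch, not a tested recipe: the names `save_layer_states` and `restore_layer_states` are hypothetical, `path` is assumed to end in ".npz", and the restore step assumes the layer offers `reset_states(states=...)`, as Keras 2.x recurrent layers do (other versions may differ). Reading the states with `np.asarray` is also an assumption that fits eager execution; with a graph-mode backend you would fetch the values with `.eval()` as shown earlier.

```python
import numpy as np

def save_layer_states(layer, path):
    """Write every state array of a stateful recurrent layer to one .npz file."""
    # Assumption: each state converts cleanly with np.asarray (eager mode).
    arrays = [np.asarray(state) for state in layer.states]
    np.savez(path, *arrays)  # stored under the keys arr_0, arr_1, ...

def restore_layer_states(layer, path):
    """Read the arrays back and hand them to the layer in the original order."""
    with np.load(path) as data:
        arrays = [data[f"arr_{k}"] for k in range(len(data.files))]
    # Assumption: the layer accepts a list of arrays here (Keras 2.x RNN layers).
    layer.reset_states(states=arrays)
```

Usage might look like `save_layer_states(model.layers[i], "lstm_states.npz")` right after training and `restore_layer_states(model.layers[i], "lstm_states.npz")` right after loading 'model.h5', before calling model.predict. The batch size must match the one the states were saved with, since each state has one row per sample in the batch.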

Upvotes: 1
