Nima Mousavi

Reputation: 1661

Keras redefine input shape

Let's say I want to train a GRU, and because I need stateful=True, the batch size has to be known beforehand.

Using the functional API I would have an Input as follows:

input_1 = Input(batch_shape=(batch_size, None, features))

But when I evaluate the model, I don't want to pass my test data in batches with fixed timesteps; I want batch_size = 1, i.e. predictions for one observation at a time. My current solution is to load the saved model and rebuild it with:

input_1 = Input(shape=(None, num_input_dim))

To do that, though, I need a method that goes through every layer of the model and then sets the weights afterwards.

    input_1 = Input(shape=(None, num_input_dim))
    x1 = input_1
    weights = []
    # Rebuild each GRU/Dense layer on the new input and collect its weights
    for layer in layers:
        if isinstance(layer, keras.layers.GRU):
            x1 = GRU(layer.output_shape[-1], return_sequences=True)(x1)
            weights.append(layer.get_weights())
        elif isinstance(layer, keras.layers.Dense):
            x1 = Dense(layer.output_shape[-1], activation='tanh')(x1)
            weights.append(layer.get_weights())
        # any other layer type is skipped

(This is just an example, and I find this solution very inelegant.)

There must be a better way to redefine the input shape. Can somebody help me out here, please?

Upvotes: 1

Views: 217

Answers (1)

Daniel Möller

Reputation: 86630

Since you're not using a stateful=True model for evaluation, you do need to redefine the model.

You can make a function to create the model taking the options as input:

def createModel(stateful, weights=None):

    # input: a stateful model needs a fixed batch size; otherwise leave it open
    if stateful:
        batch = batch_size
    else:
        batch = None

    # You don't need fixed timesteps, even if the model is stateful
    input_1 = Input(batch_shape=(batch, None, num_input_dim))

    #layer creation as you did with your first model
    ...
    out = LSTM(...., stateful=stateful)(someInput)
    ...

    model = Model(input_1,out)

    if weights is not None:
        model.set_weights(weights)

    return model

Work sequence:

# create the training model
trainModel = createModel(True, None)

# train
...

# create the other model
newModel = createModel(False, trainModel.get_weights())
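
Handing the weights from one model to the other works because a recurrent layer's parameters depend only on the number of input features and the number of units, not on the batch size or the number of timesteps. The following is a minimal NumPy sketch of that idea, not Keras's actual GRU implementation: the function names (`init_gru_weights`, `gru_forward`) and the sizes used are hypothetical, and the GRU is simplified to a single bias vector.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def init_gru_weights(features, units, seed=0):
    """Create GRU parameters; note that no batch size or sequence length is involved."""
    rng = np.random.default_rng(seed)
    kernel = rng.normal(scale=0.1, size=(features, 3 * units))   # input -> z, r, candidate
    recurrent = rng.normal(scale=0.1, size=(units, 3 * units))   # hidden -> z, r, candidate
    bias = np.zeros(3 * units)
    return kernel, recurrent, bias

def gru_forward(x, weights):
    """Run a simplified GRU over x of shape (batch, timesteps, features).

    Returns the final hidden state of shape (batch, units).
    """
    kernel, recurrent, bias = weights
    units = recurrent.shape[0]
    h = np.zeros((x.shape[0], units))          # state is sized from the data, not the weights
    for t in range(x.shape[1]):
        xp = x[:, t, :] @ kernel + bias
        hp = h @ recurrent[:, :2 * units]
        z = sigmoid(xp[:, :units] + hp[:, :units])               # update gate
        r = sigmoid(xp[:, units:2 * units] + hp[:, units:])      # reset gate
        cand = np.tanh(xp[:, 2 * units:] + (r * h) @ recurrent[:, 2 * units:])
        h = z * h + (1.0 - z) * cand
    return h

# One set of weights, as if obtained from training with a fixed batch size of 32
weights = init_gru_weights(features=4, units=8)

rng = np.random.default_rng(1)
train_batch = rng.normal(size=(32, 10, 4))     # 32 sequences, 10 timesteps each
single = train_batch[:1, :7, :]                # one observation, 7 timesteps

out_batch = gru_forward(train_batch, weights)  # shape (32, 8)
out_single = gru_forward(single, weights)      # shape (1, 8), same weights reused
```

The same three arrays serve both calls, which is the property the `get_weights()` / `set_weights()` hand-off relies on, provided both models are built with the same layers in the same order.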

Upvotes: 1
