Black Jack 21

Reputation: 339

How to get output of intermediate Keras layers in batches?

I am not sure how to get the output of an intermediate layer in Keras. I have read the other questions on Stack Overflow, but their answers seem to be functions that take a single sample as input. I want to get the output features of an intermediate layer in batches as well. Here is my model:

model = Sequential()
model.add(ResNet50(include_top = False, pooling = RESNET50_POOLING_AVERAGE, weights = resnet_weights_path)) #None
model.add(Dense(784, activation = 'relu'))
model.add(Dense(NUM_CLASSES, activation = DENSE_LAYER_ACTIVATION))
model.layers[0].trainable = True

After training the model, in my code I want to get the output after the first dense layer (784 dimensional). Is this the right way to do it?

pred = model.layers[1].predict_generator(data_generator, steps = len(data_generator), verbose = 1)

I am new to Keras so I am a little unsure. Do I need to compile the model again after training?

Upvotes: 2

Views: 2652

Answers (1)

TF_Support

Reputation: 1836

No, you don't need to compile again after training.

Based on your Sequential model, the layers are indexed as follows:

Layer 0 :: model.add(ResNet50(include_top = False, pooling = RESNET50_POOLING_AVERAGE, weights = resnet_weights_path)) #None
Layer 1 :: model.add(Dense(784, activation = 'relu'))
Layer 2 :: model.add(Dense(NUM_CLASSES, activation = DENSE_LAYER_ACTIVATION))

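The layer indices may differ if the model is built with the Functional API, where model.layers also includes the input layer.

Using TensorFlow 2.1.0, you could try this approach when you want to access intermediate outputs: build a second model that shares the trained layers and ends at the layer you are interested in.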

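from tensorflow.keras.models import Model

# Sub-model that reuses the trained layers and stops at the 784-unit Dense layer
model_dense_784 = Model(inputs=model.input, outputs = model.layers[1].output)

# steps = 1 runs a single batch only; pass steps = len(train_data_gen) to cover the whole generator
pred_dense_784 = model_dense_784.predict(train_data_gen, steps = 1) # predict_generator is deprecated

print(pred_dense_784.shape) # Use this to check the output shape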
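
Since steps = 1 covers only one batch, extracting features for an entire dataset needs enough steps to reach every sample. Here is a minimal sketch of that arithmetic, assuming the generator yields fixed-size batches with a possibly smaller final batch. The helper steps_for and the sample counts are hypothetical and not part of Keras; for Sequence-style generators, len(generator) normally gives the same number.

```python
import math


def steps_for(num_samples: int, batch_size: int) -> int:
    """Hypothetical helper: batches needed to see every sample exactly once."""
    if num_samples < 0 or batch_size <= 0:
        raise ValueError("num_samples must be >= 0 and batch_size must be > 0")
    # Round up so that a final, partially filled batch is still processed.
    return math.ceil(num_samples / batch_size)


# Made-up numbers for illustration: 1000 images in batches of 32.
steps = steps_for(1000, 32)
print(steps)  # 32 (31 full batches plus one batch of 8)
```

With these numbers, predicting for 32 steps would yield one 784-dimensional row per image, so 1000 rows in total.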

It is highly advisable to use model.predict() rather than model.predict_generator(), which is deprecated.
You can also check the shape attribute of the result to verify that it matches the output shape shown by model.summary().
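
To try the idea end to end without the ResNet50 weights, here is a small self-contained sketch. It rests on several assumptions: a tiny Dense layer stands in for the ResNet50 backbone, random arrays stand in for images, a tf.data.Dataset replaces the data generator, and the layer names are made up for the example. It illustrates the pattern only and does not reproduce the model from the question.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Dense, Input

# Toy stand-in for the trained model (assumption: a small Dense layer
# replaces the ResNet50 backbone so that no weights file is needed).
model = Sequential([
    Input(shape=(32,)),
    Dense(64, activation="relu", name="backbone_stand_in"),
    Dense(784, activation="relu", name="dense_784"),
    Dense(10, activation="softmax", name="classifier"),
])

# Sub-model sharing the same layers and weights, ending at the 784-unit layer.
# Looking the layer up by name avoids relying on its position in model.layers.
feature_model = Model(inputs=model.inputs,
                      outputs=model.get_layer("dense_784").output)

# Random data in place of real images: 40 samples in batches of 16 (16 + 16 + 8).
x = np.random.rand(40, 32).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices(x).batch(16)

# predict() walks through every batch and concatenates the per-batch outputs.
features = feature_model.predict(dataset, verbose=0)
print(features.shape)  # (40, 784)
```

No extra compile() call is needed for feature_model, because predict() only runs a forward pass through layers that already hold their weights.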

Upvotes: 1
