cheerss

Reputation: 86

How to get intermediate outputs when using tf.keras.applications

tf.keras.applications contains many well-known neural networks such as VGG, DenseNet, MobileNet and so on. Taking tf.keras.applications.MobileNet as an example, I am interested not only in the final output but also in the outputs of intermediate layers. How can I get all of these outputs while retraining the network?

Maybe model.get_output_at(index) helps. However, every time I call this function I get a DeferredTensor, because no data is being fed forward at that point. Is there a convenient way to do this?

Thanks in advance~

Upvotes: 2

Views: 4647

Answers (1)

itsreddy

Reputation: 114

I suggest reading the Keras documentation:

One simple way is to create a new Model that will output the layers that you are interested in:

from keras.models import Model

model = ...  # create the original model

layer_name = 'my_layer'
intermediate_layer_model = Model(inputs=model.input,
                                 outputs=model.get_layer(layer_name).output)
intermediate_output = intermediate_layer_model.predict(data)
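Applied to the MobileNet case from the question, the same idea can return several intermediate outputs at once by passing a list to outputs. The following is only a rough sketch, assuming TensorFlow 2.x with tf.keras; the layer names conv_pw_5_relu and conv_pw_11_relu, the 128x128 input size and the dummy batch are illustrative choices, not something taken from the question (check model.summary() for the names in your version). Because the new model reuses the layers of the original one, training it updates the same weights.

```python
import numpy as np
import tensorflow as tf

# weights=None avoids downloading pretrained weights for this sketch.
base = tf.keras.applications.MobileNet(
    input_shape=(128, 128, 3), include_top=False, weights=None
)

# Assumed layer names; inspect base.summary() to pick the ones you need.
layer_names = ["conv_pw_5_relu", "conv_pw_11_relu"]
outputs = [base.get_layer(name).output for name in layer_names]
outputs.append(base.output)  # keep the final output as well

# One model, several outputs; it shares its layers with `base`.
multi_output_model = tf.keras.Model(inputs=base.input, outputs=outputs)

# Dummy batch standing in for real images.
data = np.zeros((1, 128, 128, 3), dtype="float32")
feat_5, feat_11, final = multi_output_model.predict(data, verbose=0)
```

Each returned array corresponds to one entry of outputs, in the same order.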

Alternatively, you can build a Keras function that will return the output of a certain layer given a certain input, for example:

from keras import backend as K

# with a Sequential model
get_3rd_layer_output = K.function([model.layers[0].input],
                                  [model.layers[3].output])
layer_output = get_3rd_layer_output([x])[0]

Similarly, you could build a Theano and TensorFlow function directly.

Note that if your model has a different behavior in training and testing phase (e.g. if it uses Dropout, BatchNormalization, etc.), you will need to pass the learning phase flag to your function:

get_3rd_layer_output = K.function([model.layers[0].input, K.learning_phase()],
                                  [model.layers[3].output])

# output in test mode = 0
layer_output = get_3rd_layer_output([x, 0])[0]

# output in train mode = 1
layer_output = get_3rd_layer_output([x, 1])[0]
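With tf.keras in TensorFlow 2.x, the train/test distinction is usually expressed through the training argument when calling a model, rather than through a learning phase flag. A minimal sketch under that assumption follows; the tiny model and the layer names hidden and drop are made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Small illustrative model with a Dropout layer (hypothetical architecture).
inputs = tf.keras.Input(shape=(4,))
h = tf.keras.layers.Dense(8, name="hidden")(inputs)
h = tf.keras.layers.Dropout(0.5, name="drop")(h)
outputs = tf.keras.layers.Dense(1)(h)
model = tf.keras.Model(inputs, outputs)

# Sub-model that stops at the intermediate layer of interest.
sub_model = tf.keras.Model(inputs=model.input,
                           outputs=model.get_layer("drop").output)

x = np.ones((2, 4), dtype="float32")

# Test mode: Dropout is inactive, so the result is deterministic.
test_output = sub_model(x, training=False).numpy()

# Train mode: Dropout randomly zeroes units and rescales the rest.
train_output = sub_model(x, training=True).numpy()
```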

Here is another similar answer written by fchollet himself: How can I get hidden layer representation of the given data?

Upvotes: 8
