Ophir Yoktan

Reputation: 8449

How to split an existing keras model into two separate models?

I have a Keras model (already trained) that I want to split into two parts: one part computes an internal representation from the original inputs, and the other part computes the output from the precomputed internal representation.

Getting the first part (inputs to internal representation) is simple, but the second part is problematic.
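For reference, the "simple" first part is usually built by creating a new Model from the original input to the output of the internal layer. A minimal sketch, assuming Keras 2.x (tf.keras) or Keras 3; the small network and the layer name "representation" are hypothetical stand-ins for the real trained model:

```python
import numpy as np
import keras  # assumes Keras 2.x (tf.keras) or Keras 3

# Stand-in for the already-trained model; the layer names are hypothetical.
inputs = keras.Input(shape=(8,))
x = keras.layers.Dense(16, activation="relu", name="hidden_1")(inputs)
x = keras.layers.Dense(4, activation="relu", name="representation")(x)
outputs = keras.layers.Dense(1, name="head")(x)
model = keras.Model(inputs, outputs)

# First part: reuse the existing graph from the original input up to the
# chosen internal layer. No layers are copied, so the trained weights are shared.
encoder = keras.Model(inputs=model.input,
                      outputs=model.get_layer("representation").output)

sample = np.random.rand(3, 8).astype("float32")
representation = encoder.predict(sample, verbose=0)
print(representation.shape)  # (3, 4)
```

This trick does not carry over directly to the second part, because a new model needs fresh Input tensors rather than an intermediate tensor of the old graph.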

I found two related answers, but both are problematic in my case. The methods they describe are:

  1. How to split a model into two separate models?

In this solution you redefine the second part of the network. This appears feasible, but it requires a significant amount of code duplication (the network is quite complex).

  2. How do I split a convolutional autoencoder?

In this solution the model is defined as a composition of two models. This looks like a good solution, but it does not apply to an existing, already-trained network.

Upvotes: 2

Views: 3595

Answers (3)

Shawn

Reputation: 611

You can use the following function to split a model:

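    from keras.layers import Input, InputLayer
    from keras.models import Model

    def get_bottom_top_model(model, layer_name):
        # Assumes the layers form a single linear chain (e.g. a Sequential model);
        # branches and skip connections are not handled.
        split_layer = model.get_layer(layer_name)
        bottom_input = Input(model.input_shape[1:])
        bottom_output = bottom_input
        top_input = Input(split_layer.output_shape[1:])
        top_output = top_input

        bottom = True
        for layer in model.layers:
            # Functional models list their InputLayer first; it must not be re-applied.
            if isinstance(layer, InputLayer):
                continue
            if bottom:
                bottom_output = layer(bottom_output)
            else:
                top_output = layer(top_output)
            # Every layer after layer_name goes into the top model
            if layer.name == layer_name:
                bottom = False

        bottom_model = Model(bottom_input, bottom_output)
        top_model = Model(top_input, top_output)

        return bottom_model, top_model

    bottom_model, top_model = get_bottom_top_model(model, "dense_1")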

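layer_name is the name of the layer at which to split the model; that layer becomes the last layer of bottom_model. Because the function re-applies model.layers one after another, it only works when the layers form a single linear chain.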
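For a plain Sequential model, a shorter alternative (a sketch, not the answerer's method) is to slice model.layers by name and wrap each slice in a new Sequential model. The layer objects are reused, so both halves share the trained weights. This assumes Keras 2.x (tf.keras) or Keras 3; split_sequential and the layer names are hypothetical:

```python
import numpy as np
import keras  # assumes Keras 2.x (tf.keras) or Keras 3


def split_sequential(model, layer_name):
    """Hypothetical helper: split a Sequential model after the layer called layer_name.

    The returned models reuse the original layer objects, so they share the
    trained weights. Only valid for a plain linear stack of layers.
    """
    names = [layer.name for layer in model.layers]
    cut = names.index(layer_name) + 1  # raises ValueError for an unknown name
    bottom = keras.Sequential(model.layers[:cut])
    top = keras.Sequential(model.layers[cut:])
    return bottom, top


# Stand-in for a trained model (layer names are hypothetical).
model = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(16, activation="relu", name="dense_a"),
    keras.layers.Dense(4, activation="relu", name="dense_1"),
    keras.layers.Dense(1, name="dense_out"),
])

bottom_model, top_model = split_sequential(model, "dense_1")

# Chaining the two halves should reproduce the original model's output.
x = np.random.rand(5, 8).astype("float32")
full = model.predict(x, verbose=0)
chained = top_model.predict(bottom_model.predict(x, verbose=0), verbose=0)
print(np.allclose(full, chained, atol=1e-5))  # True
```

The top half has no explicit input shape, so Keras builds it on the first call from the shape of the representation it receives.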

Upvotes: 1

Amine Sehaba

Reputation: 120

Here is my solution (only for Sequential models). I have used it with MobileNetV2 and it worked perfectly for me. Just call the function with the pre-trained model and the index at which to split, and it returns the two split models:

    from keras.layers import Input, InputLayer
    from keras.models import Model

    def split_keras_model(model, index):
        '''
        Input:
          model: A pre-trained Keras model whose layers form a single linear
                 chain (e.g. a Sequential model)
          index: The index (in model.layers) of the first layer of model2
        Output:
          model1: The layers before index, i.e. model.layers[:index]
          model2: From layer index to the output of the original model
        The input of model2 has the same shape as the input of layer index,
        which is also the output shape of model1.
        '''
        # Creating the first part...
        # A new input tensor with the original model's input shape
        layer_input_1 = Input(model.input_shape[1:])
        x = layer_input_1
        # Connect each layer before index. Functional models list their
        # InputLayer first; skip it instead of re-applying it.
        for layer in model.layers[:index]:
            if isinstance(layer, InputLayer):
                continue
            x = layer(x)
        model1 = Model(inputs=layer_input_1, outputs=x)

        # Creating the second part...
        # Get the input shape of the layer at index
        input_shape_2 = model.layers[index].get_input_shape_at(0)[1:]
        print("Input shape of model 2: " + str(input_shape_2))
        # A new input tensor to be able to feed the desired layer
        layer_input_2 = Input(shape=input_shape_2)
        x = layer_input_2
        # Connect each remaining layer to the new model
        for layer in model.layers[index:]:
            x = layer(x)
        model2 = Model(inputs=layer_input_2, outputs=x)

        return (model1, model2)

Upvotes: 2

Ophir Yoktan

Reputation: 8449

The best solution that I found:

  1. define a "nested" model (a composition of sub-models), as suggested in this answer

  2. make sure the layer names correspond to the layer names in the old model; this is the important part, as it makes the layer mapping simple

  3. copy the weights from the old model to the new one, as in this example:

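    import keras

    # Assumes new_model (nested) and original_model (trained, flat) are Keras Models.
    # For every sub-model, copy weights into each layer that has weights,
    # looking up the layer with the same name in the original model.
    for sub_model in filter(lambda l: isinstance(l, keras.models.Model), new_model.layers):
        for layer in filter(lambda l: l.weights, sub_model.layers):
            layer.set_weights(original_model.get_layer(layer.name).get_weights())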
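A minimal end-to-end sketch of these three steps, assuming Keras 2.x (tf.keras) or Keras 3. The tiny network and the layer names enc_dense and dec_dense are hypothetical stand-ins for the real architecture: it defines the encoder and decoder as separate sub-models, composes them, copies the weights by name from a flat "original" model, and checks that the outputs match.

```python
import numpy as np
import keras  # assumes Keras 2.x (tf.keras) or Keras 3


def build_encoder():
    inp = keras.Input(shape=(8,))
    h = keras.layers.Dense(16, activation="relu", name="enc_dense")(inp)
    return keras.Model(inp, h, name="encoder")


def build_decoder():
    inp = keras.Input(shape=(16,))
    out = keras.layers.Dense(1, name="dec_dense")(inp)
    return keras.Model(inp, out, name="decoder")


# Stand-in for the original, already-trained flat model. Its layer names
# must match the names used inside the sub-models.
flat_in = keras.Input(shape=(8,))
flat_h = keras.layers.Dense(16, activation="relu", name="enc_dense")(flat_in)
flat_out = keras.layers.Dense(1, name="dec_dense")(flat_h)
original_model = keras.Model(flat_in, flat_out)

# Step 1: the new model is a composition of two sub-models.
encoder, decoder = build_encoder(), build_decoder()
nested_in = keras.Input(shape=(8,))
new_model = keras.Model(nested_in, decoder(encoder(nested_in)))

# Step 3: copy weights layer by layer, matching on layer names.
for sub_model in (l for l in new_model.layers if isinstance(l, keras.Model)):
    for layer in (l for l in sub_model.layers if l.weights):
        layer.set_weights(original_model.get_layer(layer.name).get_weights())

# The two halves, used separately, now reproduce the original model.
x = np.random.rand(4, 8).astype("float32")
same = np.allclose(original_model.predict(x, verbose=0),
                   decoder.predict(encoder.predict(x, verbose=0), verbose=0),
                   atol=1e-5)
print(same)  # True
```

Because encoder and decoder are ordinary Keras models, each can be saved or called on its own after the copy, which gives the two separate models the question asks for.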

Upvotes: 3
