mrgloom

Reputation: 21682

Keras: how to load weights partially?

How can I load model weights partially? For example, I want to load only block1 of the VGG19 model using the original ImageNet weights (vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5):

def VGG19_part(input_shape=None):
    """Build only block 1 of VGG19 and try to load the full notop ImageNet weights into it.

    Args:
        input_shape: Shape passed to ``tf.keras.layers.Input``.

    Returns:
        The ``tf.keras.Model`` for block 1. As written, the ``load_weights``
        call raises ``ValueError: You are trying to load a weight file
        containing 16 layers into a model with 2 layers`` before the model
        is returned.
    """
    img_input = tf.keras.layers.Input(shape=input_shape)

    # Block 1
    # The convolutions are linear; ReLU is applied by the separate Activation layers.
    x = tf.keras.layers.Conv2D(64, (3, 3),
                      activation='linear',
                      padding='same',
                      name='block1_conv1')(img_input)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Conv2D(64, (3, 3),
                      activation='linear',
                      padding='same',
                      name='block1_conv2')(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    model = tf.keras.Model(img_input, x, name='vgg19')

    # This is where the reported ValueError comes from: the file holds weights
    # for 16 layers, but this model has only 2 weighted layers
    # (block1_conv1, block1_conv2).
    # NOTE(review): for HDF5 files, Keras' load_weights(..., by_name=True) loads
    # only the layers whose names match, which looks like the intent here.
    # Confirm against the installed TF/Keras version.
    # NOTE(review): the absolute, user-specific path makes this non-portable.
    model.load_weights('/Users/myuser/.keras/models/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')

    # NOTE(review): Model.summary() prints by itself and returns None, so this
    # line also prints "None"; a bare model.summary() call is enough.
    print(model.summary())

    return model

This code produces an error: ValueError: You are trying to load a weight file containing 16 layers into a model with 2 layers.

Upvotes: 2

Views: 1891

Answers (1)

Marco Cerliani

Reputation: 22031

The VGG19 model from the Keras applications module loads the ImageNet weights by default, so I use it as the source from which to copy the weights we are interested in into our custom model:

input_shape = (224, 224, 3)

# Full VGG19 without the fully connected top, with ImageNet weights;
# it serves only as the source of pretrained weights.
full_vgg19 = tf.keras.applications.VGG19(
    input_shape=input_shape,
    include_top=False,
    weights='imagenet',
)

def VGG19_part(full_vgg19, input_shape=None):
    """Build VGG19 block 1 and copy its weights from a full pretrained VGG19.

    Weights are copied layer by layer, matched on layer name, rather than by
    slicing the flat ``full_vgg19.get_weights()`` list. This way the copy
    does not depend on the order in which ``full_vgg19`` stores its weights.
    It also keeps working if more blocks are added here, as long as their
    layer names match the corresponding layers in ``full_vgg19``.

    Args:
        full_vgg19: A pretrained VGG19 ``tf.keras.Model`` (for example from
            ``tf.keras.applications.VGG19``) containing layers named
            ``block1_conv1`` and ``block1_conv2``.
        input_shape: Shape passed to ``tf.keras.layers.Input``.

    Returns:
        A ``tf.keras.Model`` whose block-1 weights equal those of the
        same-named layers in ``full_vgg19``.

    Raises:
        ValueError: If ``full_vgg19`` has no layer with the name of a
            weighted layer in this model (from ``Model.get_layer``), or if the
            weight shapes differ (from ``Layer.set_weights``).
    """
    img_input = tf.keras.layers.Input(shape=input_shape)

    # Block 1. Each convolution is linear and is followed by a separate ReLU
    # Activation layer, which is equivalent to Conv2D(..., activation='relu').
    x = tf.keras.layers.Conv2D(64, (3, 3),
                               activation='linear',
                               padding='same',
                               name='block1_conv1')(img_input)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Conv2D(64, (3, 3),
                               activation='linear',
                               padding='same',
                               name='block1_conv2')(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    model = tf.keras.Model(img_input, x, name='vgg19')

    # Copy weights by layer name. Layers without weights (Input, Activation,
    # MaxPooling2D) are skipped because their auto-generated names need not
    # exist in full_vgg19, and there is nothing to copy for them anyway.
    for layer in model.layers:
        if not layer.weights:
            continue
        source_layer = full_vgg19.get_layer(layer.name)
        layer.set_weights(source_layer.get_weights())

    return model

part_vgg19 = VGG19_part(full_vgg19, input_shape)

### check that the weights/biases match the pretrained source layers:
# Compare layer by layer, matched on name, and fail loudly on any mismatch.
# A bare list expression would only be displayed in an interactive session;
# in a script its result would be silently discarded.
for layer in part_vgg19.layers:
    if not layer.weights:
        continue  # Input/Activation/MaxPooling2D layers carry no weights
    source_weights = full_vgg19.get_layer(layer.name).get_weights()
    for got, expected in zip(layer.get_weights(), source_weights):
        assert got.shape == expected.shape and (got == expected).all(), layer.name

Upvotes: 1

Related Questions