Reputation: 21682
How can I load model weights partially? For example, I want to load only block1 of the VGG19 model using the original ImageNet weights (`vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5`):
def VGG19_part(input_shape=None, weights_path=None):
    """Build VGG19's block 1 and load its ImageNet weights from an HDF5 file.

    ``Model.load_weights`` on an HDF5 file matches layers by topological
    order by default. That is why loading the full VGG19 "notop" file
    (16 weighted layers, per the error it raised) into this 2-conv-layer
    model failed with ``ValueError``. Passing ``by_name=True`` matches
    layers by name instead: only ``block1_conv1`` and ``block1_conv2`` are
    filled from the file, and the file's other layers are ignored.

    Args:
        input_shape: Input shape without the batch dimension, e.g.
            ``(224, 224, 3)``. ``None`` (the default) means
            ``(None, None, 3)``: the block is fully convolutional, so any
            spatial size is accepted.
        weights_path: Path to the VGG19 "notop" weights file in HDF5 format.
            Defaults to
            ``/Users/myuser/.keras/models/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5``.

    Returns:
        A ``tf.keras.Model`` mapping the input image to the ``block1_pool``
        output.
    """
    if input_shape is None:
        # Input() rejects shape=None, so fall back to a variable-size RGB input.
        input_shape = (None, None, 3)
    if weights_path is None:
        weights_path = ('/Users/myuser/.keras/models/'
                        'vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')

    img_input = tf.keras.layers.Input(shape=input_shape)
    # Block 1. Each conv is linear and followed by a separate ReLU layer,
    # which is equivalent to Conv2D(..., activation='relu').
    x = tf.keras.layers.Conv2D(64, (3, 3),
                               activation='linear',
                               padding='same',
                               name='block1_conv1')(img_input)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Conv2D(64, (3, 3),
                               activation='linear',
                               padding='same',
                               name='block1_conv2')(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
    model = tf.keras.Model(img_input, x, name='vgg19')

    # The conv layer names above must match the names stored in the weights
    # file for by_name loading to pick them up.
    model.load_weights(weights_path, by_name=True)
    # summary() prints its table itself and returns None, so wrapping it in
    # print() only emitted a stray "None".
    model.summary()
    return model
This code produces an error: `ValueError: You are trying to load a weight file containing 16 layers into a model with 2 layers.`
Upvotes: 2
Views: 1891
Reputation: 22031
The VGG19 model from the Keras applications module comes with ImageNet weights by default, so I use it as the source of the weights we are interested in and copy them into our custom model:
# Image size (height, width, channels) shared by the full and partial models.
input_shape = (224, 224, 3)

# Full VGG19 convolutional base (no classifier head) with ImageNet weights;
# it is only used as the source of the pretrained block-1 weights.
full_vgg19 = tf.keras.applications.VGG19(
    weights='imagenet',
    include_top=False,
    input_shape=input_shape,
)
def VGG19_part(full_vgg19, input_shape=None):
    """Build VGG19's block 1 and copy its weights from a pretrained VGG19.

    Weights are copied layer by layer, matching each weighted layer of the
    new model to the identically named layer of ``full_vgg19``. This avoids
    slicing ``full_vgg19.get_weights()`` by position, so the copy does not
    depend on the order in which the full model lists its variables.

    Args:
        full_vgg19: A Keras model containing layers named ``block1_conv1``
            and ``block1_conv2`` with matching weight shapes, e.g. the model
            returned by ``tf.keras.applications.VGG19``.
        input_shape: Input shape without the batch dimension, e.g.
            ``(224, 224, 3)``. ``None`` (the default) means
            ``(None, None, 3)``: the block is fully convolutional, so any
            spatial size is accepted.

    Returns:
        A ``tf.keras.Model`` mapping the input image to the ``block1_pool``
        output, with block-1 kernels and biases taken from ``full_vgg19``.

    Raises:
        ValueError: If ``full_vgg19`` has no layer with one of the required
            names (raised by ``get_layer``), or if weight shapes differ
            (raised by ``set_weights``).
    """
    if input_shape is None:
        # Input() rejects shape=None, so fall back to a variable-size RGB input.
        input_shape = (None, None, 3)

    img_input = tf.keras.layers.Input(shape=input_shape)
    # Block 1. Each conv is linear and followed by a separate ReLU layer,
    # which is equivalent to Conv2D(..., activation='relu').
    x = tf.keras.layers.Conv2D(64, (3, 3),
                               activation='linear',
                               padding='same',
                               name='block1_conv1')(img_input)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Conv2D(64, (3, 3),
                               activation='linear',
                               padding='same',
                               name='block1_conv2')(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
    model = tf.keras.Model(img_input, x, name='vgg19')

    # Only the conv layers own weights; Input/Activation/MaxPooling layers
    # have an empty weight list and are skipped.
    for layer in model.layers:
        if layer.weights:
            layer.set_weights(full_vgg19.get_layer(layer.name).get_weights())
    return model
part_vgg19 = VGG19_part(full_vgg19, input_shape)

# Check that the partial model's kernels/biases equal the corresponding
# pretrained arrays. A bare expression here would be discarded when run as a
# script, so the result is printed explicitly.
# Expected output: [True, True, True, True]
weights_match = [
    bool((part_w == full_w).all())
    for part_w, full_w in zip(part_vgg19.get_weights(), full_vgg19.get_weights())
]
print(weights_match)
Upvotes: 1