Reputation: 983
I am trying to perform transfer learning with a ResNet50 model pretrained on ImageNet for the PASCAL VOC 2012 dataset. As it is a multi-label dataset, I am using a sigmoid activation function in the final layer and binary_crossentropy loss. The metrics are precision, recall and accuracy. Below is the code I used to build the model for 20 classes (PASCAL VOC has 20 classes).
img_height,img_width = 128,128
num_classes = 20
#If imagenet weights are being loaded,
#input must have a static square shape (one of (128, 128), (160, 160), (192, 192), or (224, 224))
base_model = applications.resnet50.ResNet50(weights= 'imagenet', include_top=False, input_shape= (img_height,img_width,3))
x = base_model.output
x = GlobalAveragePooling2D()(x)
#x = Dropout(0.7)(x)
predictions = Dense(num_classes, activation= 'sigmoid')(x)
model = Model(inputs = base_model.input, outputs = predictions)
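# Keep the last two layers (pooling and the Dense classifier) trainable
for layer in model.layers[-2:]:
    layer.trainable = True
# Freeze the earlier layers
for layer in model.layers[:-3]:
    layer.trainable = False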
adam = Adam(lr=0.0001)
model.compile(optimizer= adam, loss='binary_crossentropy', metrics=['accuracy',precision_m,recall_m])
X_train, X_test, Y_train, Y_test = train_test_split(x_train, y, random_state=42, test_size=0.2)
savingcheckpoint = ModelCheckpoint('ResnetTL.h5',monitor='val_loss',verbose=1,save_best_only=True,mode='min')
earlystopcheckpoint = EarlyStopping(monitor='val_loss',patience=10,verbose=1,mode='min',restore_best_weights=True)
model.fit(X_train, Y_train, epochs=epochs, validation_data=(X_test,Y_test), batch_size=batch_size,callbacks=[savingcheckpoint,earlystopcheckpoint],shuffle=True)
It ran for 35 epochs until early stopping, and the metrics are as follows (without the Dropout layer):
loss: 0.1195 - accuracy: 0.9551 - precision_m: 0.8200 - recall_m: 0.5420 - val_loss: 0.3535 - val_accuracy: 0.8358 - val_precision_m: 0.0583 - val_recall_m: 0.0757
Even with the Dropout layer, I don't see much difference:
loss: 0.1584 - accuracy: 0.9428 - precision_m: 0.7212 - recall_m: 0.4333 - val_loss: 0.3508 - val_accuracy: 0.8783 - val_precision_m: 0.0595 - val_recall_m: 0.0403
With dropout, the model reaches a validation precision and accuracy of 0.2 for a few epochs, but never goes above that.
The precision and recall on the validation set are very low compared to the training set, both with and without the Dropout layer. How should I interpret this? Does this mean the model is overfitting? If so, what should I do? As of now, the model's predictions are quite random (totally incorrect). The dataset contains 11,000 images.
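One way to read these numbers: with sparse multi-label targets, element-wise accuracy can stay high even when almost no positive label is predicted correctly, which is what a validation accuracy near 0.85 next to a precision and recall below 0.1 suggests. The following is a rough pure-Python sketch of that effect. The four label vectors are made-up illustrative data, and `binary_metrics` and `multi_hot` are hypothetical helpers, not the `precision_m` / `recall_m` functions used above. It also assumes that 'accuracy' is computed element-wise over all 20 outputs, which is how Keras typically resolves it when the loss is binary_crossentropy.

```python
NUM_CLASSES = 20


def multi_hot(indices):
    """Build a 20-dimensional 0/1 label vector from a set of class indices."""
    return [1 if i in indices else 0 for i in range(NUM_CLASSES)]


def binary_metrics(y_true, y_pred):
    """Element-wise accuracy, precision and recall over multi-hot label vectors."""
    tp = fp = fn = tn = 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            if t == 1 and p == 1:
                tp += 1
            elif t == 0 and p == 1:
                fp += 1
            elif t == 1 and p == 0:
                fn += 1
            else:
                tn += 1
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return accuracy, precision, recall


# Made-up targets: each image carries only one or two of the 20 labels.
y_true = [multi_hot({0, 14}), multi_hot({6}), multi_hot({11, 14}), multi_hot({8})]

# A useless model that never predicts any label at all.
all_zero_pred = [multi_hot(set()) for _ in y_true]

acc, prec, rec = binary_metrics(y_true, all_zero_pred)
print(f"accuracy={acc:.3f} precision={prec:.3f} recall={rec:.3f}")
# accuracy=0.925 precision=0.000 recall=0.000
```

In this toy case 74 of the 80 label slots are negatives, so predicting nothing already scores 92.5% accuracy. Precision and recall are therefore the more informative metrics here.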
Upvotes: 1
Views: 1047
Reputation: 2552
In this example, you froze all layers except the last two (the global average pooling layer and the final Dense layer). There is a cleaner way to achieve the same state:
rn50 = applications.resnet50.ResNet50(weights='imagenet', include_top=False,
input_shape=(img_height, img_width, 3))
x = rn50.output
x = GlobalAveragePooling2D()(x)
predictions = Dense(num_classes, activation= 'sigmoid')(x)
model = Model(inputs=rn50.input, outputs=predictions)
rn50.trainable = False # <- this
In this case, features are extracted from the ResNet50 network and fed to a linear classifier with sigmoid outputs, but the ResNet50 weights are not trained. This is called feature extraction, not fine-tuning.
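As a quick sanity check of the frozen state, here is a minimal sketch that counts which weights remain trainable. It assumes `tf.keras` is available, and it passes `weights=None` only to skip the ImageNet download for this check (use `weights='imagenet'` for real training).

```python
import tensorflow as tf

img_height, img_width, num_classes = 128, 128, 20

# weights=None: random initialization, chosen here only to avoid the download.
rn50 = tf.keras.applications.ResNet50(
    weights=None, include_top=False, input_shape=(img_height, img_width, 3)
)
x = tf.keras.layers.GlobalAveragePooling2D()(rn50.output)
predictions = tf.keras.layers.Dense(num_classes, activation="sigmoid")(x)
model = tf.keras.Model(inputs=rn50.input, outputs=predictions)

# Freeze every layer that belongs to the backbone.
rn50.trainable = False

num_trainable = len(model.trainable_weights)
num_frozen = len(model.non_trainable_weights)
print("trainable weight tensors:", num_trainable)
print("frozen weight tensors:", num_frozen)
```

With the backbone frozen, the only trainable tensors should be the kernel and bias of the final Dense layer.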
The only weights being trained are those of your classifier, which was initialized with weights drawn from a random distribution and therefore has to be trained from scratch. For this stage you should be using Adam with its default learning rate (0.001).
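So you can train it for a few epochs and, once that is done, unfreeze the backbone and "fine-tune" it:
backbone = rn50  # the pretrained network defined above
backbone.trainable = False
model.fit(X_train, Y_train, epochs=50)
backbone.trainable = True  # recompile after changing trainable
model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.00001), loss='binary_crossentropy')
model.fit(X_train, Y_train, epochs=total_epochs, initial_epoch=50)  # total_epochs > 50, your choice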
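The two stages can be sketched end to end as follows. This is only an illustration under stated assumptions: the tiny convolutional backbone, the random data and the epoch counts (2, then 4 in total) are stand-ins chosen so the snippet runs in seconds. With the real model you would plug in the ResNet50 backbone and your own data.

```python
import numpy as np
import tensorflow as tf

num_classes = 20

# Stand-in backbone (hypothetical): one small conv layer instead of ResNet50.
backbone = tf.keras.Sequential(
    [tf.keras.Input(shape=(32, 32, 3)),
     tf.keras.layers.Conv2D(8, 3, activation="relu")],
    name="backbone",
)
inputs = tf.keras.Input(shape=(32, 32, 3))
x = backbone(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(num_classes, activation="sigmoid")(x)
model = tf.keras.Model(inputs, outputs)

# Random stand-in data: 16 images with sparse multi-hot labels.
X = np.random.rand(16, 32, 32, 3).astype("float32")
Y = (np.random.rand(16, num_classes) > 0.9).astype("float32")

# Stage 1: feature extraction. Backbone frozen, default Adam (0.001).
backbone.trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(), loss="binary_crossentropy")
frozen_count = len(model.trainable_weights)
model.fit(X, Y, epochs=2, batch_size=8, verbose=0)

# Stage 2: fine-tuning. Unfreeze, recompile with a much lower learning
# rate, and resume the epoch counter where stage 1 stopped.
backbone.trainable = True
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss="binary_crossentropy",
)
unfrozen_count = len(model.trainable_weights)
history = model.fit(X, Y, epochs=4, initial_epoch=2, batch_size=8, verbose=0)

print("trainable tensors, stage 1 vs stage 2:", frozen_count, unfrozen_count)
print("epochs run in stage 2:", history.epoch)
```

Recompiling after changing `trainable` matters because the change only takes effect for training once the model is compiled again. `epochs` is the final epoch index, so `epochs=4, initial_epoch=2` runs two more epochs.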
There is a nice article about this on the Keras website.
Upvotes: 0
Reputation: 986
Can you modify the code as shown below and try running it?
Change
predictions = Dense(num_classes, activation= 'sigmoid')(x)
to
predictions = Dense(num_classes, activation= 'softmax')(x)
and change
model.compile(optimizer= adam, loss='binary_crossentropy', metrics=['accuracy',precision_m,recall_m])
to
model.compile(optimizer= adam, loss='categorical_crossentropy', metrics=['accuracy',precision_m,recall_m])
Upvotes: 1