Ankish Bansal

Reputation: 1900

How to split resnet50 model from top as well as from bottom?

I am using keras pretrained model with include_top=False, but i also want to remove one resblock from top and one from bottom as well. For vgg net, it is simple, because of straight forward links in layers, but in resnet, architecture is complicated because of skip connection, so direct approach doesn't fit well.

Can somebody recommend any resource or scripts to do it?

resnet = tf.keras.applications.resnet50.ResNet50(include_top=False, weights='imagenet')

Upvotes: 2

Views: 18070

Answers (1)

Adria Ciurana

Reputation: 934

If I understand correctly, you want to remove the first residual block and the last one.

My advice is to use resnet.summary() to visualize all the layer names in the model, or better yet TensorBoard, to see the relationships clearly.

Keep in mind that a block in a Residual Network ends with a sum (an Add layer) immediately followed by an activation; that activation is the layer you want to cut at.

The block names look like res2a...: the number indicates the stage and the letter the "sub-block" within it.
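Instead of reading the names off the summary by hand, you can also collect the block-ending activations programmatically. This is a sketch I am adding beyond the original answer, and it also works when your Keras version uses the newer layer names (e.g. conv2_block3_out instead of res2c's activation): it walks the layer list, which follows creation order, and records every Activation that directly follows an Add.

```python
import tensorflow as tf

# weights=None skips the ImageNet download; the graph structure is identical.
resnet = tf.keras.applications.resnet50.ResNet50(include_top=False, weights=None)

# A residual block ends with Add -> Activation, so record those activations.
block_ends = []
prev = None
for layer in resnet.layers:
    if isinstance(layer, tf.keras.layers.Activation) and isinstance(prev, tf.keras.layers.Add):
        block_ends.append(layer.name)
    prev = layer

# ResNet50 has 3 + 4 + 6 + 3 = 16 residual blocks, one name per block end.
print(block_ends)
```

The resulting list gives you, in order, one candidate cut point per residual block, whatever naming scheme your Keras version uses.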

Based on the Resnet50 architecture:

[ResNet50 architecture diagram]

If I want to remove the first residual block, I look for the end of res2c. In the summary I find this:

activation_57 (Activation)      (None, 56, 56, 64)   0           bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 56, 56, 64)   36928       activation_57[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 56, 56, 64)   256         res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_58 (Activation)      (None, 56, 56, 64)   0           bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 56, 56, 256)  16640       activation_58[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 56, 56, 256)  1024        res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_19 (Add)                    (None, 56, 56, 256)  0           bn2c_branch2c[0][0]
                                                                 activation_56[0][0]
__________________________________________________________________________________________________
activation_59 (Activation)      (None, 56, 56, 256)  0           add_19[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 28, 28, 128)  32896       activation_59[0][0]

The new input layer is res3a_branch2a; this way I skip the first residual block.

activation_87 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2a[0][0]              
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_87[0][0]              
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2b[0][0]             
__________________________________________________________________________________________________
activation_88 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2b[0][0]              
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_88[0][0]              
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4f_branch2c[0][0]             
__________________________________________________________________________________________________
add_29 (Add)                    (None, 14, 14, 1024) 0           bn4f_branch2c[0][0]              
                                                                 activation_86[0][0]              
__________________________________________________________________________________________________
activation_89 (Activation)      (None, 14, 14, 1024) 0           add_29[0][0]   

If I want to remove the last residual block, I look for the end of res4. That is activation_89.

Making these cuts we would have this model:

[Diagram of the trimmed model]

resnet_cut = Model(inputs=resnet.get_layer('res3a_branch2a').input, outputs=resnet.get_layer('activation_89').output)
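Note that Model() expects tensors, not layers, hence the .input/.output accessors. Depending on your TF/Keras version, building a model from an intermediate input tensor may also raise an error; trimming from the output side, however, always works. Below is a sketch of that safer half of the cut (removing the last residual stage) that avoids hard-coding version-specific layer names; the block_ends bookkeeping is my own addition, not part of the original answer.

```python
import tensorflow as tf
from tensorflow.keras.models import Model

# weights=None skips the ImageNet download; the graph structure is identical.
resnet = tf.keras.applications.resnet50.ResNet50(include_top=False, weights=None)

# Residual blocks end with Add -> Activation; collect those activation names.
block_ends, prev = [], None
for layer in resnet.layers:
    if isinstance(layer, tf.keras.layers.Activation) and isinstance(prev, tf.keras.layers.Add):
        block_ends.append(layer.name)
    prev = layer

# The stages hold 3, 4, 6 and 3 blocks, so index 3 + 4 + 6 - 1 = 12 is the
# end of stage 4 (res4f -- 'activation_89' in the summary above).
stage4_end = block_ends[3 + 4 + 6 - 1]
resnet_top_cut = Model(inputs=resnet.input, outputs=resnet.get_layer(stage4_end).output)

print(resnet_top_cut.output_shape)  # the channel dimension should be 1024
```

If the input-side cut fails in your version, a common workaround is to feed a new tf.keras.Input through resnet_top_cut's remaining layers, but with skip connections that requires rebuilding the subgraph by hand.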

Upvotes: 3
