djbacs

Reputation: 51

How to perform pruning with a transfer learning model?

I want to apply pruning to my transfer learning model.

I used EfficientNetB0 to classify microorganisms.

import numpy as np
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute the end step so that pruning finishes after the last epoch.
batch_size = 32
epochs = 25

end_step = np.ceil(len(training_set) / batch_size).astype(np.int32) * epochs

# Define the pruning schedule.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.40,
        final_sparsity=0.90,
        begin_step=0,
        end_step=end_step,
    )
}

# Use the same variable name that is compiled and summarized below.
efficientnetb0_transfer_model_for_pruning = prune_low_magnitude(
    efficientnetb0_transfer_model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
efficientnetb0_transfer_model_for_pruning.compile(optimizer=optim,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

efficientnetb0_transfer_model_for_pruning.summary()

However, I'm getting the following error:

ValueError: Please initialize `Prune` with a supported layer. Layers should either be supported by the PruneRegistry (built-in keras layers) or should be a `PrunableLayer` instance, or should has a customer defined `get_prunable_weights` method. You passed: <class 'tensorflow.python.keras.layers.preprocessing.image_preprocessing.Rescaling'>

What could I be doing wrong?

Upvotes: 0

Views: 448

Answers (1)

Yunlu Li

Reputation: 77

You're hitting this error.

The pruning API is not flexible enough yet. It currently expects every layer in the model to be prunable (logic here), so it fails on layers such as image rescaling, which it should ideally skip. Could you file a GitHub issue so we can work on a fix? Thanks!
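In the meantime, one possible workaround is to wrap only the layers you actually want pruned and pass everything else (such as `Rescaling`) through unchanged, using `tf.keras.models.clone_model` with a `clone_function`. The sketch below is untested against your model and rests on a few assumptions: the helper names `make_selective_wrapper` and `prune_supported_layers` are made up for illustration, the choice of `Dense` and `Conv2D` as the layers to prune is mine, and the model is assumed to be flat (if EfficientNetB0 sits inside your model as a nested sub-model, you would likely need to apply the same idea to that sub-model's layers as well).

```python
def make_selective_wrapper(wrap_fn, prunable_types):
    """Build a clone_function that wraps only layers of the given types.

    wrap_fn:        callable applied to layers that should be pruned.
    prunable_types: tuple of layer classes to wrap; all other layers
                    are returned unchanged.
    """
    def clone_fn(layer):
        if isinstance(layer, prunable_types):
            return wrap_fn(layer)
        return layer  # e.g. Rescaling: left as-is, never wrapped

    return clone_fn


def prune_supported_layers(model, pruning_params):
    """Hypothetical helper: return a copy of `model` with pruning applied
    only to Dense and Conv2D layers (an assumed, adjustable selection)."""
    # Imported here so the selection helper above can be exercised
    # without TensorFlow installed.
    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    prunable_types = (tf.keras.layers.Dense, tf.keras.layers.Conv2D)

    def wrap(layer):
        return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)

    return tf.keras.models.clone_model(
        model,
        clone_function=make_selective_wrapper(wrap, prunable_types),
    )
```

With the names from the question, usage would look like `efficientnetb0_transfer_model_for_pruning = prune_supported_layers(efficientnetb0_transfer_model, pruning_params)`, followed by the same `compile(...)` call as before. Training a pruned model also needs the `tfmot.sparsity.keras.UpdatePruningStep()` callback passed to `fit`.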

Upvotes: 1
