Reputation: 3582
I'm trying to prune a base model that consists of several layers on top of a VGG network. It also contains a user-defined layer named instance_normalization. For pruning to be successful, I've defined the get_prunable_weights function of this layer as follows:
### defined for model pruning
def get_prunable_weights(self):
    return self.weights
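For reference, a typical instance-normalization layer exposing get_prunable_weights would look something like this (a generic sketch for illustration, not my exact layer):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

class InstanceNormalization(tf.keras.layers.Layer,
                            tfmot.sparsity.keras.PrunableLayer):
    """Generic sketch of an instance-normalization layer that tfmot can prune."""
    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        channels = input_shape[-1]
        # Per-channel scale and offset, learned during training.
        self.gamma = self.add_weight(name='gamma', shape=(channels,),
                                     initializer='ones', trainable=True)
        self.beta = self.add_weight(name='beta', shape=(channels,),
                                    initializer='zeros', trainable=True)

    def call(self, inputs):
        # Normalize each sample independently over its spatial dimensions.
        mean, variance = tf.nn.moments(inputs, axes=[1, 2], keepdims=True)
        normalized = (inputs - mean) / tf.sqrt(variance + self.epsilon)
        return self.gamma * normalized + self.beta

    ### defined for model pruning
    def get_prunable_weights(self):
        return self.weights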
I used the following function to obtain a to-be-pruned model structure from a base model named model:
def define_prune_model(self, model, img_shape, epochs, batch_size, validation_split=0.1):
    num_images = img_shape[0] * (1 - validation_split)
    end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

    # Define model for pruning.
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.5,
                                                                 final_sparsity=0.80,
                                                                 begin_step=0,
                                                                 end_step=end_step)
    }
    model_for_pruning = prune_low_magnitude(model, **pruning_params)

    model_for_pruning.compile(optimizer='adam',
                              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                              metrics=['accuracy'])
    model_for_pruning.summary()
    return model_for_pruning
Then, I wrote the following function to train this pruning model:
def train_prune_model(self, model_for_pruning, train_images, train_labels,
                      epochs, batch_size, validation_split=0.1):
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir='./models/pruned'),
    ]
    model_for_pruning.fit(train_images, train_labels,
                          batch_size=batch_size, epochs=epochs,
                          validation_split=validation_split,
                          callbacks=callbacks)
    return model_for_pruning
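For completeness, I call these two functions roughly like this (the argument values here are just placeholders):

pruned_model = self.define_prune_model(model, train_images.shape, epochs=10, batch_size=32)
pruned_model = self.train_prune_model(pruned_model, train_images, train_labels,
                                      epochs=10, batch_size=32)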
However, when training, I found that the training and validation losses were all nan, and the final model predictions were all zero. In contrast, the base model passed to define_prune_model had trained successfully and predicted correctly.
How can I solve this? Thank you in advance.
Upvotes: 0
Views: 287
Reputation: 61
It is difficult to pinpoint the issue without more information. In particular, can you please give more detail (preferably as code) about your custom instance_normalization layer?

Assuming that the code is fine: since you mentioned that the model trains correctly without pruning, could it be that those pruning parameters are too harsh? After all, those options set 50% of the weights to zero right from the first learning step.
Here is what I would try:

- Start with a lower sparsity (initial_sparsity).
- Start pruning later in training, once the model has already learned something (begin_step argument of the pruning schedule). Some even prefer to train the model once without applying pruning at all, then re-train again with prune_low_magnitude().
- Prune less often (frequency argument).
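Put together, a gentler schedule could look something like this (the exact values are illustrative, not tuned; end_step is computed as in your code):

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0,   # start with no weights pruned
        final_sparsity=0.80,
        begin_step=1000,        # let the model learn before pruning kicks in
        end_step=end_step,
        frequency=500)          # prune every 500 steps instead of the default 100
}
model_for_pruning = prune_low_magnitude(model, **pruning_params)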
Upvotes: 1