E_learner

Reputation: 3582

Tensorflow model pruning gives 'nan' for training and validation losses

I'm trying to prune a base model that consists of several layers on top of a VGG network. It also contains a user-defined layer named instance_normalization. For pruning to work with this custom layer, I defined its get_prunable_weights method as follows:

    ### defined for model pruning: tells tfmot which weights of this
    ### custom layer may be sparsified
    def get_prunable_weights(self):
        return self.weights
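
For context, this layer implements tfmot's PrunableLayer interface. Here is a simplified sketch of such a layer (the class name and the normalization details are illustrative, not my exact implementation):

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    # Illustrative sketch: a custom layer implements PrunableLayer so that
    # prune_low_magnitude() knows which of its weights to sparsify.
    class InstanceNormalization(tf.keras.layers.Layer,
                                tfmot.sparsity.keras.PrunableLayer):
        def build(self, input_shape):
            self.scale = self.add_weight(name='scale', shape=input_shape[-1:],
                                         initializer='ones')
            self.offset = self.add_weight(name='offset', shape=input_shape[-1:],
                                          initializer='zeros')

        def call(self, x):
            # Normalize each sample over its spatial dimensions (assumes NHWC input).
            mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
            return self.scale * (x - mean) / tf.sqrt(var + 1e-5) + self.offset

        def get_prunable_weights(self):
            # Marks every variable of this layer as prunable.
            return self.weights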

I used the following function to obtain a to-be-pruned model structure using a base model named model:

    # Assumes these module-level imports:
    #   import numpy as np
    #   import tensorflow as tf
    #   import tensorflow_model_optimization as tfmot
    #   prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
    def define_prune_model(self, model, img_shape, epochs, batch_size, validation_split=0.1):
        # Number of training images left after holding out the validation split.
        num_images = img_shape[0] * (1 - validation_split)
        end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

        # Define model for pruning: ramp sparsity from 50% to 80% over training.
        pruning_params = {
            'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.5,
                                                                     final_sparsity=0.80,
                                                                     begin_step=0,
                                                                     end_step=end_step)
        }

        model_for_pruning = prune_low_magnitude(model, **pruning_params)

        model_for_pruning.compile(optimizer='adam',
                                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                                  metrics=['accuracy'])

        model_for_pruning.summary()

        return model_for_pruning
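
For instance, with 10,000 images, batch_size = 32, validation_split = 0.1, and epochs = 4 (illustrative numbers), this gives end_step = ceil(9000 / 32) * 4 = 282 * 4 = 1128 pruning steps.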

Then I wrote the following function to train this pruning model:

    def train_prune_model(self, model_for_pruning, train_images, train_labels,
                          epochs, batch_size, validation_split=0.1):
        # UpdatePruningStep keeps the pruning wrapper in sync with the training
        # step; PruningSummaries writes sparsity logs for TensorBoard.
        callbacks = [
            tfmot.sparsity.keras.UpdatePruningStep(),
            tfmot.sparsity.keras.PruningSummaries(log_dir='./models/pruned'),
        ]
        model_for_pruning.fit(train_images, train_labels,
                              batch_size=batch_size, epochs=epochs,
                              validation_split=validation_split,
                              callbacks=callbacks)
        return model_for_pruning
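
For reference, the two methods are called roughly like this (a simplified sketch: trainer stands for the class instance these methods belong to, and the epoch and batch-size values are placeholders):

    # Hypothetical driver code; `base_model` is the already-trained base model.
    model_for_pruning = trainer.define_prune_model(
        base_model, train_images.shape, epochs=4, batch_size=32)
    model_for_pruning = trainer.train_prune_model(
        model_for_pruning, train_images, train_labels, epochs=4, batch_size=32)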

However, during training the training and validation losses were all nan, and the final model's predictions were all zero. The base model that was passed to define_prune_model had trained successfully and predicted correctly.

How can I solve this? Thank you in advance.

Upvotes: 0

Views: 287

Answers (1)

It is difficult to pinpoint the issue without more information. In particular, can you please give more detail (preferably as code) about your custom instance_normalization layer?

Assuming that the code is fine: since you mentioned that the model trains correctly without pruning, could it be that the pruning parameters are too harsh? After all, those settings set 50% of the weights to zero right from the first training step.

Here is what I would try:

  • Experiment with a lower level of sparsity (especially initial_sparsity).
  • Start applying pruning later in training (the begin_step argument of the pruning schedule). Some practitioners even prefer to train the model once without any pruning, then retrain it with prune_low_magnitude().
  • Only prune at some steps, giving the model time to recover between prunings (the frequency argument).
  • Finally, should it still fail, try the usual cures for nan losses: reduce the learning rate, use regularization or gradient clipping, and so on. A sketch combining these adjustments is given below.
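
For illustration, here is a minimal sketch of a gentler configuration (the numbers are examples to experiment with, not tuned values, and the tiny Dense model merely stands in for your VGG-based model):

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    # Stand-ins for the trained base model and the end_step from the question.
    model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(784,))])
    end_step = 4000  # compute from dataset size as in define_prune_model

    # Gentler schedule: start from 0% sparsity, begin pruning only after a
    # warm-up of 1000 steps, and prune every 100 steps so the network can
    # recover in between.
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.0,
            final_sparsity=0.8,
            begin_step=1000,
            end_step=end_step,
            frequency=100)
    }
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

    # Usual nan remedies: a lower learning rate plus gradient clipping.
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, clipnorm=1.0)
    model_for_pruning.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'])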

Upvotes: 1
