Reputation: 43
I have a network which contains Conv2D layers followed by ReLU activations, declared as such:
# Conv2D followed by a separate ReLU layer; after quantization-aware training
# this ReLU shows up as its own op in the converted TFLite graph.
x = layers.Conv2D(self.hparams['channels_count'], kernel_size=(4,1))(x)
x = layers.ReLU()(x)
It is converted to TFLite with the following representation:
Basic TFLite network without Q-aware training
However, after performing quantization-aware training on the network and porting it again, the ReLU layers are now explicit in the graph:
TFLite network after Q-aware training
This results in them being processed separately on the target instead of during the evaluation of the Conv2D kernel, inducing a 10% performance loss in my overall network.
Declaring the activation with the following implicit syntax does not produce the problem:
# ReLU given as the Conv2D activation argument; it stays fused into the convolution after QAT.
x = layers.Conv2D(self.hparams['channels_count'], kernel_size=(4,1), activation='relu')(x)
Basic TFLite network with implicit ReLU activation
TFLite network with implicit ReLU after Q-aware training
However, this restricts the network to basic ReLU activation, whereas I would like to use ReLU6 which cannot be declared in this way.
Is this a TFLite issue? If not, is there a way to prevent the ReLU layer from being split? Or alternatively, is there a way to manually merge the ReLU layers back into the Conv2D layers after the quantization-aware training?
Edit: quantization-aware training code:
def learn_qaware(self):
    """Wrap ``self.model`` for quantization-aware training (QAT) and fine-tune it.

    ``self.model`` is replaced by the output of
    ``tfmot.quantization.keras.quantize_model`` and is then compiled and fitted
    using the QAT variants of the optimizer and callbacks (``qa_learn=True``).
    Training data comes from ``self.training_set`` and validation data from
    ``self.validate_set``, each wrapped in an ``SCDataGenerator``.
    """
    quantize_model = tfmot.quantization.keras.quantize_model
    # NOTE(review): the float model is overwritten in place; calling this method
    # twice would pass an already-quantized model to quantize_model again.
    self.model = quantize_model(self.model)

    training_generator = SCDataGenerator(self.training_set)
    validate_generator = SCDataGenerator(self.validate_set)

    self.model.compile(
        optimizer=self.configure_optimizers(qa_learn=True),
        loss=self.get_LLP_loss(),
        metrics=self.get_metrics(),
        run_eagerly=config['eager_mode'],
    )
    # batch_size is deliberately not passed: SCDataGenerator presumably yields
    # ready-made batches (generator / keras Sequence), and Keras documents that
    # batch_size must not be specified for such inputs (it was ignored anyway).
    self.model.fit(
        training_generator,
        epochs=self.hparams['max_epochs'],
        shuffle=self.hparams['shuffle_curves'],
        validation_data=validate_generator,
        callbacks=self.get_callbacks(qa_learn=True),
    )
Quantized TFLite model generation code:
def tflite_convert(classifier):
    """Convert ``classifier.model`` to a full-integer (int8) TFLite model and save it.

    The model input's batch dimension is pinned to 1 while converting. Weights
    and activations are quantized to int8 (builtin int8 ops only, int8 input and
    output tensors), calibrated from ``classifier.validate_set.get_all_inputs()``.

    Args:
        classifier: object exposing ``model`` (a Keras model), ``model_path``
            and ``validate_set``.

    Returns:
        ``TFLite_model`` wrapping the written ``.tflite`` file.
    """
    output_file = get_tflite_filename(classifier.model_path)

    saved_shape = classifier.model.input.shape.as_list()
    # Copy the list: mutating an alias would also overwrite saved_shape, making
    # the restore below re-apply the batch-1 shape.
    fixed_shape = list(saved_shape)
    fixed_shape[0] = 1
    classifier.model.input.set_shape(fixed_shape)  # Force batch size to 1 for generation
    try:
        converter = tf.lite.TFLiteConverter.from_keras_model(classifier.model)

        # Full-integer quantization: default optimizations, int8-only builtin
        # ops and int8 model input/output.
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8

        def representative_dataset():
            # Calibration samples: each validation input reshaped to
            # (1, length, 1, 1) float32.
            for sample in classifier.validate_set.get_all_inputs():
                yield [sample.reshape(1, sample.shape[0], 1, 1).astype(np.float32)]

        converter.representative_dataset = representative_dataset
        # The Keras model is traced during convert(), so the batch-1 shape must
        # still be in place here.
        model_tflite = converter.convert()
    finally:
        # NOTE(review): set_shape on a Keras input may only refine a shape, so
        # this restore of a None batch dimension may be a no-op — confirm.
        classifier.model.input.set_shape(saved_shape)

    with open(output_file, "wb") as tflite_file:
        tflite_file.write(model_tflite)
    return TFLite_model(output_file)
Upvotes: 1
Views: 1125
Reputation: 43
I have found a workaround which works by instantiating a non-trained version of the model, then copying over the weights from the quantization aware trained model before converting to TFLite.
This seems like quite a hack, so I'm still on the lookout for a cleaner solution.
Code for the workaround:
def dequantize(self):
    """Swap the quantization-aware model for a float model holding its trained weights.

    A float model is obtained from ``self.get_default_model()`` (built once and
    cached in ``self.fp_model``). Each of its layers, except those whose name
    contains ``'input'`` or ``'quantize_layer'``, is matched to the layer named
    ``"quant_" + layer.name`` in ``self.model``. Its trainable weights are
    overwritten with the matching weights from that layer. Afterwards
    ``self.model`` is the float model.

    Raises:
        RuntimeError: if a float layer, or one of its trainable weights, has no
            counterpart in the quantization-aware model.
    """
    QUANT_TAG = "quant_"

    if not getattr(self, 'fp_model', None):
        self.fp_model = self.get_default_model()

    def find_layer_in_model(name, model):
        # First layer of `model` named `name`, or None.
        return next((candidate for candidate in model.layers if candidate.name == name), None)

    def find_weight_group_in_layer(name, layer):
        # First trainable weight of `layer` named `name`, or None.
        return next((weight for weight in layer.trainable_weights if weight.name == name), None)

    for layer in self.fp_model.layers:
        if 'input' in layer.name or 'quantize_layer' in layer.name:
            continue
        quant_layer = find_layer_in_model(QUANT_TAG + layer.name, self.model)
        if quant_layer is None:
            raise RuntimeError('Failed to match layer ' + layer.name)
        # NOTE(review): only trainable weights are transferred; non-trainable
        # state (e.g. BatchNormalization moving statistics) keeps the fresh
        # model's values — confirm the architecture has none.
        for weight_group in layer.trainable_weights:
            # The QAT weight may or may not carry the "quant_" prefix; try both.
            # Explicit `is None` checks: truth-testing a tf.Variable can raise.
            quant_weight_group = find_weight_group_in_layer(QUANT_TAG + weight_group.name, quant_layer)
            if quant_weight_group is None:
                quant_weight_group = find_weight_group_in_layer(weight_group.name, quant_layer)
            if quant_weight_group is None:
                raise RuntimeError('Failed to match weight group ' + weight_group.name)
            weight_group.assign(quant_weight_group)

    self.model = self.fp_model
Upvotes: 2