Reputation: 29
I am using the ONNX Python library. I am trying to quantize AI models statically using the quantize_static() function imported from onnxruntime.quantization.
This function takes a calibration_data_reader object as its third argument, and I could not find a comprehensive explanation of what this object is or how to create one.
How can I create this object? Also, would it be possible to create a generic version of the CalibrationDataReader that can be used with multiple models?
I have tried looking this up on the ONNX website and reading through examples on the GitHub page, but I cannot find an explanation or documentation I understand, as most data readers seem to be pre-made for one specific model.
The best thing I could come up with was something like this, using a calibrator from the onnxruntime.quantization tools:
from onnxruntime.quantization import quantize_static, calibrate

def quantize(model_path, output_path):
    calibrator = calibrate.create_calibrator(model_path, calibrate_method=calibrate.CalibrationMethod.MinMax)
    quantize_static(model_input=model_path, model_output=output_path, calibration_data_reader=calibrator)
Which throws this error:
quantize_static(model_input= model_path, model_output= output_path, calibration_data_reader= calibrator)
File "████Python\Python311\Lib\site-packages\onnxruntime\quantization\quantize.py", line 435, in quantize_static
calibrator.collect_data(calibration_data_reader)
File "████Python\Python311\Lib\site-packages\onnxruntime\quantization\calibrate.py", line 301, in collect_data
inputs = data_reader.get_next()
^^^^^^^^^^^^^^^^^^^^
AttributeError: 'MinMaxCalibrater' object has no attribute 'get_next'
I am probably passing the wrong type of object here, but I am out of ideas. An explanation or a link to a resource covering this topic would help.
Upvotes: 3
Views: 744
Reputation: 1
I had the same problem and could not find any solid external resource or documentation on this either.
Judging from onnxruntime's calibrate.py code, CalibrationDataReader is a base class for an iterator-like child class that you need to implement yourself. The child class should provide a get_next method (returning calibration samples in the format the model accepts) and a __len__ method (returning the number of samples).
Here is a sample of such a child class for a folder of images to be used for calibration:
import os

import cv2
from onnxruntime.quantization import quantize_static, quant_pre_process, CalibrationDataReader

class ImageFolderCalibrationDataReader(CalibrationDataReader):
    def __init__(self, folder_path):
        super().__init__()
        self.folder_path = folder_path
        self.calibration_images = os.listdir(folder_path)
        self.current_item = 0

    def get_next(self) -> dict:
        """Generate the input data dict in the input format of your ONNX model."""
        if self.current_item == len(self.calibration_images):
            return None  # None signals that calibration is finished
        image = cv2.imread(os.path.join(self.folder_path, self.calibration_images[self.current_item]))
        # Some image preprocessing to match the model's input requirements
        image = cv2.resize(image, (224, 224))
        image = image.astype('float32') / 255.
        # Depending on the model, you may also need a batch dimension
        # and/or an NCHW layout, e.g.:
        # image = image.transpose(2, 0, 1)[None, ...]
        self.current_item += 1
        return {"input": image}

    def __len__(self) -> int:
        """Get the length of the calibration dataset."""
        return len(self.calibration_images)

if __name__ == '__main__':
    # Run the quantization process
    model_initial_path = "path/to/model.onnx"
    model_prep_path = "path/to/model.prep.onnx"
    model_quant_path = "path/to/model.quant.onnx"
    calibration_data_path = "calibration_data/"

    calibration_dataset = ImageFolderCalibrationDataReader(calibration_data_path)
    quant_pre_process(model_initial_path, model_prep_path)
    quantize_static(model_prep_path, model_quant_path, calibration_dataset)
This class is valid only if the original ONNX model expects a single input named "input", i.e. it can be run like this:
import onnxruntime as ort
model_initial_path = "path/to/model.onnx"
model = ort.InferenceSession(model_initial_path)
results = model.run(None, {"input": image})
You can check the model's expected input keys using:
print([input.name for input in model.get_inputs()])
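Building on that, you can go fully generic: read the input names and shapes from the session and build one feed dict per calibration sample. Below is a sketch (the helper names are my own); note that random data gives poor calibration ranges, so substitute real, representative samples in practice:

```python
import numpy as np

def concrete_shape(shape):
    """Replace symbolic/unknown dimensions (strings or None) with 1."""
    return [d if isinstance(d, int) and d > 0 else 1 for d in shape]

def random_feeds(model_path, num_samples=10):
    """Build one feed dict per sample from a model's declared inputs.

    Placeholder data only; real calibration needs representative samples.
    """
    import onnxruntime as ort  # imported lazily so concrete_shape stays standalone
    session = ort.InferenceSession(model_path)
    feeds = []
    for _ in range(num_samples):
        feeds.append({
            inp.name: np.random.rand(*concrete_shape(inp.shape)).astype(np.float32)
            for inp in session.get_inputs()
        })
    return feeds
```

Each dict in the returned list is exactly what get_next should hand out, so a single reader implementation can then serve multiple models.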
Upvotes: 0