Zylon

Reputation: 29

ONNX-Python: Can someone explain the CalibrationDataReader required by the quantize_static() function?

I am using the ONNX Runtime Python library and am trying to statically quantize AI models with the quantize_static() function from onnxruntime.quantization.

This function takes a calibration_data_reader object as its third argument, and I could not find a comprehensive explanation of what this object is or how to create one.

How can I create this object? Also, would it be possible to write a generic CalibrationDataReader that can be used with multiple models?

I have tried looking this up on the ONNX website and reading through the examples on the GitHub page, but I cannot find an explanation or documentation I can understand, as most data readers there seem to be written for one specific model.

The best I could come up with was something like this, using a calibrator from the onnxruntime.quantization tools:

from onnxruntime.quantization import quantize_static, calibrate

def quantize(model_path, output_path):
    calibrator = calibrate.create_calibrator(model_path, calibrate_method= calibrate.CalibrationMethod.MinMax)
    quantize_static(model_input= model_path, model_output= output_path, calibration_data_reader= calibrator)

This throws the following error:

quantize_static(model_input= model_path, model_output= output_path, calibration_data_reader= calibrator)
File "████Python\Python311\Lib\site-packages\onnxruntime\quantization\quantize.py", line 435, in quantize_static
calibrator.collect_data(calibration_data_reader)
File "████Python\Python311\Lib\site-packages\onnxruntime\quantization\calibrate.py", line 301, in collect_data
inputs = data_reader.get_next()
         ^^^^^^^^^^^^^^^^^^^^
AttributeError: 'MinMaxCalibrater' object has no attribute 'get_next'

I am probably using the wrong type of object here, but I am out of ideas. An explanation or link to a resource explaining the topic would help me with this.

Upvotes: 3

Views: 744

Answers (1)

Andrii Shevtsov

Reputation: 1

I ran into the same problem and haven't found any useful external resource or documentation on it either.

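Based on onnxruntime's calibrate.py, CalibrationDataReader is an abstract base class that you subclass yourself. quantize_static() creates the calibrator internally (the calibration method is chosen through its calibrate_method argument) and then calls get_next() on whatever you pass as calibration_data_reader, which is why passing a calibrator fails. Your subclass should provide a get_next method, which returns one calibration sample per call as a dict mapping the model's input names to numpy arrays and returns None once the data is exhausted, and a __len__ method (returning the number of samples).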
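
To make that contract concrete, here is a minimal sketch of a reader that serves pre-built input feeds from memory. The class name InMemoryCalibrationDataReader, the input name "input" and the (1, 3, 224, 224) shape are hypothetical; the random arrays only illustrate the format, since real calibration needs representative samples. The try/except stand-in base class exists only so the sketch runs without onnxruntime installed.

```python
import abc

import numpy as np

try:
    from onnxruntime.quantization import CalibrationDataReader
except ImportError:
    # Stand-in so this sketch runs without onnxruntime; the real base class
    # also declares get_next() as an abstract method.
    class CalibrationDataReader(abc.ABC):
        @abc.abstractmethod
        def get_next(self):
            raise NotImplementedError


class InMemoryCalibrationDataReader(CalibrationDataReader):
    """Serves pre-built input feeds one at a time, then signals the end with None."""

    def __init__(self, feeds):
        # feeds: list of dicts mapping a model input name to a numpy array
        self.feeds = list(feeds)
        self.index = 0

    def get_next(self):
        if self.index >= len(self.feeds):
            return None  # tells the calibrator there is no more data
        feed = self.feeds[self.index]
        self.index += 1
        return feed

    def __len__(self):
        return len(self.feeds)


# Hypothetical model with a single float32 input named "input"
rng = np.random.default_rng(0)
reader = InMemoryCalibrationDataReader(
    [{"input": rng.random((1, 3, 224, 224), dtype=np.float32)} for _ in range(3)]
)
```

An instance like reader can then be passed directly as the third argument of quantize_static().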

Here is an example of such a subclass that reads calibration images from a folder:

import os

import cv2
import numpy as np
from onnxruntime.quantization import quantize_static, quant_pre_process, CalibrationDataReader


class ImageFolderCalibrationDataReader(CalibrationDataReader):
    def __init__(self, folder_path):
        super().__init__()

        self.folder_path = folder_path
        # keep only image files so cv2.imread never returns None
        self.calibration_images = sorted(
            f for f in os.listdir(folder_path)
            if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp"))
        )
        self.current_item = 0

    def get_next(self) -> dict | None:
        """Return the next input dict in the format your ONNX model expects, or None when done."""
        if self.current_item == len(self.calibration_images):
            return None  # None signals that the calibration is finished

        image = cv2.imread(os.path.join(self.folder_path, self.calibration_images[self.current_item]))
        # some image preprocessing to match the model's input requirements;
        # this assumes a float32 NCHW input of shape (1, 3, 224, 224).
        # cv2 loads images as BGR: add cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if the model expects RGB
        image = cv2.resize(image, (224, 224))
        image = image.astype(np.float32) / 255.
        image = np.expand_dims(image.transpose(2, 0, 1), axis=0)  # HWC -> 1xCxHxW

        self.current_item += 1

        return {"input": image}

    def __len__(self) -> int:
        """Return the number of calibration samples."""
        return len(self.calibration_images)


if __name__ == '__main__':
    # Run the quantization process
    model_initial_path = "path/to/model.onnx"
    model_prep_path = "path/to/model.prep.onnx"
    model_quant_path = "path/to/model.quant.onnx"

    calibration_data_path = "calibration_data/"
    calibration_dataset = ImageFolderCalibrationDataReader(calibration_data_path)

    quant_pre_process(model_initial_path, model_prep_path)
    quantize_static(model_prep_path, model_quant_path, calibration_dataset)

This class only works if the key in the returned dictionary ("input") matches the name of the model's input, i.e. the same input feed you would pass to InferenceSession.run():

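import numpy as np
import onnxruntime as ort

model_initial_path = "path/to/model.onnx"
model = ort.InferenceSession(model_initial_path)
image = np.zeros((1, 3, 224, 224), dtype=np.float32)  # placeholder; use the model's real input shape
results = model.run(None, {"input": image})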

You can list the model's expected input names (and shapes) with:

print([(inp.name, inp.shape) for inp in model.get_inputs()])
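
For the generic reader asked about in the question, one option is to stop hard-coding the key and pair each sample's arrays with the model's own input names. This is a sketch under assumptions: GenericCalibrationDataReader and its from_model() helper are hypothetical names, each sample is assumed to be a tuple with one array per model input in get_inputs() order, and the two-input model in the usage example is made up. The try/except stand-in base class only lets the sketch run without onnxruntime installed.

```python
import abc

import numpy as np

try:
    from onnxruntime.quantization import CalibrationDataReader
except ImportError:
    # Stand-in so this sketch runs without onnxruntime installed
    class CalibrationDataReader(abc.ABC):
        @abc.abstractmethod
        def get_next(self):
            raise NotImplementedError


class GenericCalibrationDataReader(CalibrationDataReader):
    """Model-agnostic reader: zips each sample's arrays with the model's input names."""

    def __init__(self, input_names, samples):
        self.input_names = list(input_names)
        # samples: iterable of tuples, one array per model input, in input order
        self.samples = list(samples)
        self.index = 0

    @classmethod
    def from_model(cls, model_path, samples):
        """Hypothetical helper: read the input names from the ONNX model itself."""
        import onnxruntime as ort  # imported here so the class is usable without a model file
        session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
        return cls([inp.name for inp in session.get_inputs()], samples)

    def get_next(self):
        if self.index >= len(self.samples):
            return None  # end of calibration data
        sample = self.samples[self.index]
        self.index += 1
        if len(sample) != len(self.input_names):
            raise ValueError(f"expected {len(self.input_names)} arrays, got {len(sample)}")
        return dict(zip(self.input_names, sample))

    def __len__(self):
        return len(self.samples)


# Usage with a made-up two-input model; with a real model you would call
# GenericCalibrationDataReader.from_model("path/to/model.onnx", samples) instead.
reader = GenericCalibrationDataReader(
    ["input_ids", "attention_mask"],
    [(np.ones((1, 8), dtype=np.int64), np.ones((1, 8), dtype=np.int64))],
)
```

The model-specific part (loading and preprocessing samples) stays outside the reader, so the same class can feed any model as long as the samples are already in the shapes and dtypes it expects.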

Upvotes: 0
