3 | 3 | # Licensed under the MIT License. See License.txt in the project root for
4 | 4 | # license information.
5 | 5 | # --------------------------------------------------------------------------
| 6 | +from __future__ import annotations |
| 7 | + |
| 8 | +import copy |
6 | 9 | import logging
7 | 10 | import tempfile
8 | 11 | from pathlib import Path
9 | | -from typing import Union |
| 12 | +from typing import Any, Callable |
10 | 13 |
11 | 14 | import onnx
12 | 15 |
13 | 16 | from .calibrate import CalibrationDataReader, CalibrationMethod, TensorsData, create_calibrator
14 | 17 | from .onnx_quantizer import ONNXQuantizer
15 | 18 | from .qdq_quantizer import QDQQuantizer
16 | 19 | from .quant_utils import (
| 20 | +    MODEL_SIZE_THRESHOLD, |
17 | 21 |     QuantFormat,
18 | 22 |     QuantizationMode,
19 | 23 |     QuantType,
22 | 26 |     save_and_reload_model_with_shape_infer,
23 | 27 | )
24 | 28 | from .registry import IntegerOpsRegistry, QDQRegistry, QLinearOpsRegistry
| 29 | +from .tensor_quant_overrides import TensorQuantOverridesHelper |
25 | 30 |
26 | 31 |
27 | 32 | class QuantConfig:
@@ -213,6 +218,163 @@ def __init__(
213 | 218 |         self.extra_options = extra_options or {}
214 | 219 |
215 | 220 |
| 221 | +def get_qdq_config( |
| 222 | + model_input: str | Path | onnx.ModelProto, |
| 223 | + calibration_data_reader: CalibrationDataReader, |
| 224 | + calibrate_method=CalibrationMethod.MinMax, |
| 225 | + calibrate_args: dict[str, Any] | None = None, |
| 226 | + activation_type=QuantType.QUInt8, |
| 227 | + weight_type=QuantType.QInt8, |
| 228 | + activation_symmetric: bool = False, |
| 229 | + weight_symmetric: bool | None = None, |
| 230 | + per_channel: bool = False, |
| 231 | + keep_removable_activations: bool = False, |
| 232 | + min_real_range: float | None = None, |
| 233 | + tensor_quant_overrides: dict[str, list[dict[str, Any]]] | None = None, |
| 234 | + nodes_to_exclude: list[str] | Callable[[onnx.ModelProto, onnx.NodeProto], bool] | None = None, |
| 235 | + extra_options: dict | None = None, |
| 236 | +) -> StaticQuantConfig: |
| 237 | + """ |
| 238 | + Returns a configuration suitable for quantizing the entire model to integer precision. |
| 239 | +
| 240 | + Params: |
| 241 | + model_input: Path to the input model file or ModelProto. |
| 242 | + calibration_data_reader: Calibration data reader. |
| 243 | + calibrate_method: The calibration method. Defaults to MinMax. |
| 244 | + activation_type: The default activation quantization type. Defaults to QUInt8. |
| 245 | + weight_type: The default weight quantization type. Defaults to QInt8. |
| 246 | + activation_symmetric: True if activations should be quantized symmetrically (i.e., rmax == -rmin) by default. |
| 247 | + Defaults to false. For int8 and int16, this results in zero-point values of 0. For uint8 and uint16, |
| 248 | + the zero-point values are 127 and 32,767, respectively. |
| 249 | + weight_symmetric: True if weights should be quantized symmetrically (i.e., rmax == -rmin) by default. |
| 250 | + Defaults to None. If set to None, weight_symmetric is assumed true if a weight's quant type is a signed int. |
| 251 | + per_channel: Global option that determines if a fixed set of operator types should be quantized per-channel. |
| 252 | + Defaults to false. Alternatively, use the tensor-level `tensor_quant_overrides` to select individual operators |
| 253 | + and their quantization axes. |
| 254 | + keep_removable_activations: Defaults to false. If true, "removable" activations (e.g., Clip or Relu) will not |
| 255 | + be removed, and will be explicitly represented in the QDQ model. If false, these activations |
| 256 | + are automatically removed if activations are asymmetrically quantized. Keeping these activations |
| 257 | + is necessary if optimizations or EP transformations will later remove |
| 258 | + QuantizeLinear/DequantizeLinear operators from the model. |
| 259 | + min_real_range: Default is None. If set to a floating-point value, the calculation of the quantization parameters |
| 260 | + (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax - rmin) |
| 261 | + is less than the specified minimum range, rmax will be set to rmin + min_real_range. |
| 262 | + tensor_quant_overrides: tensor-level quantization overrides. Defaults to None. |
| 263 | + The key is a tensor name and the value is a list of dictionaries. For per-tensor quantization, the list |
| 264 | + contains a single dictionary. For per-channel quantization, the list contains either a dictionary for |
| 265 | + each channel in the tensor or a single dictionary that is assumed to apply to all channels. An 'axis' |
| 266 | + key must be present in the first dictionary for per-channel quantization. |
| 267 | +
| 268 | + Each dictionary contains optional overrides with the following keys and values. |
| 269 | + 'quant_type' = QuantType : The tensor's quantization data type. |
| 270 | + 'axis' = Int : The per-channel axis. Must be present for per-channel weights. |
| 271 | + 'scale' = Float : The scale value to use. Must also specify `zero_point` if set. |
| 272 | + 'zero_point' = Int : The zero-point value to use. Must also specify `scale` if set. |
| 273 | + 'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if |
| 274 | + `scale` or `zero_point` is also set. |
| 275 | + 'reduce_range' = Bool : If the quantization range should be reduced. Invalid if |
| 276 | + `scale` or `zero_point` is also set. Only valid for initializers. |
| 277 | + 'rmax' = Float : Override the maximum real tensor value in calibration data. |
| 278 | + Invalid if `scale` or `zero_point` is also set. |
| 279 | + 'rmin' = Float : Override the minimum real tensor value in calibration data. |
| 280 | + Invalid if `scale` or `zero_point` is also set. |
| 281 | + 'convert' = Dict : A nested dictionary with the same keys for an activation |
| 282 | + tensor that should be converted to another quantization type. |
| 283 | + 'convert["recv_nodes"]' = Set : Set of node names that consume the converted activation, |
| 284 | + other nodes get the original type. If not specified, |
| 285 | + assume all consumer nodes get the converted type. |
| 286 | + nodes_to_exclude: List of node names to exclude from quantization. Alternatively, a function that |
| 287 | + accepts an onnx.ModelProto and onnx.NodeProto as arguments and returns True if the given onnx.NodeProto |
| 288 | + should be excluded from quantization. |
| 289 | + extra_options: Additional options specified as string key/value pairs. Refer to the documentation for |
| 290 | + `quantize_static` for valid keys and values. |
| 291 | +
| 292 | + Returns: |
| 293 | + A StaticQuantConfig object |
| 294 | + """ |
| 295 | + q16_types = {QuantType.QInt16, QuantType.QUInt16} |
| 296 | + q4_types = {QuantType.QInt4, QuantType.QUInt4} |
| 297 | + op_types_to_exclude = {"Cast", "DequantizeLinear", "QuantizeLinear"} |
| 298 | + |
| 299 | + model = ( |
| 300 | + model_input |
| 301 | + if isinstance(model_input, onnx.ModelProto) |
| 302 | + else onnx.load_model(model_input, load_external_data=False) |
| 303 | + ) |
| 304 | + |
| 305 | + op_types = set() |
| 306 | + model_has_external_data = False |
| 307 | + overrides_helper = TensorQuantOverridesHelper( |
| 308 | + copy.deepcopy(tensor_quant_overrides) if tensor_quant_overrides else {} |
| 309 | + ) |
| 310 | + |
| 311 | + # check if the model has external data. |
| 312 | + for initializer in model.graph.initializer: |
| 313 | + if onnx.external_data_helper.uses_external_data(initializer): |
| 314 | + model_has_external_data = True |
| 315 | + |
| 316 | + final_nodes_to_exclude = [] |
| 317 | + if nodes_to_exclude is not None and isinstance(nodes_to_exclude, list): |
| 318 | + final_nodes_to_exclude.extend(nodes_to_exclude) |
| 319 | + |
| 320 | + # Iterate through nodes to get all operator types in the model and |
| 321 | + # call user's function to filter out nodes from quantization. |
| 322 | + for node in model.graph.node: |
| 323 | + op_types.add(node.op_type) |
| 324 | + if nodes_to_exclude is not None and callable(nodes_to_exclude): |
| 325 | + if nodes_to_exclude(model, node): |
| 326 | + final_nodes_to_exclude.append(node.name) |
| 327 | + |
| 328 | + final_extra_options = { |
| 329 | + "MinimumRealRange": min_real_range, |
| 330 | + "QDQKeepRemovableActivations": keep_removable_activations, |
| 331 | + "ActivationSymmetric": activation_symmetric, |
| 332 | + "WeightSymmetric": weight_symmetric, |
| 333 | + "ForceQuantizeNoInputCheck": True, |
| 334 | + "TensorQuantOverrides": overrides_helper.get_dict(), |
| 335 | + } |
| 336 | + |
| 337 | + # Pass along known calibration options |
| 338 | + if calibrate_args: |
| 339 | + calib_extra_options_keys = [ |
| 340 | + ("symmetric", "CalibTensorRangeSymmetric"), |
| 341 | + ("moving_average", "CalibMovingAverage"), |
| 342 | + ("averaging_constant", "CalibMovingAverageConstant"), |
| 343 | + ("max_intermediate_outputs", "CalibMaxIntermediateOutputs"), |
| 344 | + ("percentile", "CalibPercentile"), |
| 345 | + ] |
| 346 | + calib_extra_options = { |
| 347 | + key: calibrate_args.get(name) for (name, key) in calib_extra_options_keys if name in calibrate_args |
| 348 | + } |
| 349 | + final_extra_options.update(calib_extra_options) |
| 350 | + |
| 351 | + # ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain |
| 352 | + # on Q/DQ operators if using 16-bit or 4-bit quantization. |
| 353 | + onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx") |
| 354 | + if onnx_opset.version < 21: |
| 355 | + opset21_types = q16_types.union(q4_types) |
| 356 | + overrides_have_opset21_types = any(t in opset21_types for t in overrides_helper.get_quant_types()) |
| 357 | + if activation_type in opset21_types or weight_type in opset21_types or overrides_have_opset21_types: |
| 358 | + final_extra_options["UseQDQContribOps"] = True |
| 359 | + |
| 360 | + # Allow user's extra_options to override our final_extra_options. |
| 361 | + if extra_options: |
| 362 | + final_extra_options.update(extra_options) |
| 363 | + |
| 364 | + return StaticQuantConfig( |
| 365 | + calibration_data_reader, |
| 366 | + calibrate_method=calibrate_method, |
| 367 | + quant_format=QuantFormat.QDQ, |
| 368 | + activation_type=activation_type, |
| 369 | + weight_type=weight_type, |
| 370 | + op_types_to_quantize=list(op_types.difference(op_types_to_exclude)), |
| 371 | + nodes_to_exclude=final_nodes_to_exclude, |
| 372 | + per_channel=per_channel, |
| 373 | + use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), |
| 374 | + extra_options=final_extra_options, |
| 375 | + ) |
| 376 | + |
| 377 | + |
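The helper above assembles a StaticQuantConfig for full-model QDQ quantization. A minimal usage sketch follows; the model path, input name and shape, the overridden tensor name, and the random-data reader are hypothetical placeholders, and `get_qdq_config` is imported from this module (`onnxruntime.quantization.quantize`) since the top-level re-export depends on the package version.

```python
# Illustrative sketch only: model path, input name/shape, and tensor/node names are placeholders.
import numpy as np

from onnxruntime.quantization import CalibrationDataReader, QuantType
from onnxruntime.quantization.quantize import get_qdq_config


class RandomCalibrationReader(CalibrationDataReader):
    """Feeds a few random samples; replace with real calibration data."""

    def __init__(self, input_name, shape, num_samples=8):
        self._data = iter(
            [{input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(num_samples)]
        )

    def get_next(self):
        return next(self._data, None)


reader = RandomCalibrationReader("input", (1, 3, 224, 224))

qdq_config = get_qdq_config(
    "model.fp32.onnx",  # placeholder path; an onnx.ModelProto is also accepted
    reader,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    per_channel=False,
    # Tensor-level override: quantize one (hypothetical) activation to 16 bits.
    tensor_quant_overrides={"conv1_out": [{"quant_type": QuantType.QUInt16}]},
    # Exclude nodes with a callable filter instead of a name list.
    nodes_to_exclude=lambda model, node: node.op_type == "Softmax",
)
```

Per the code above, the resulting config selects every operator type found in the model except Cast, QuantizeLinear, and DequantizeLinear, and enables external-data serialization automatically for models that use external data or exceed the size threshold.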
216 | 378 | class DynamicQuantConfig(QuantConfig):
217 | 379 |     def __init__(
218 | 380 |         self,
@@ -290,8 +452,8 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua
290 | 452 |
291 | 453 |
292 | 454 | def quantize_static(
293 | | -    model_input: Union[str, Path, onnx.ModelProto], |
294 | | -    model_output: Union[str, Path], |
| 455 | +    model_input: str | Path | onnx.ModelProto, |
| 456 | +    model_output: str | Path, |
295 | 457 |     calibration_data_reader: CalibrationDataReader,
296 | 458 |     quant_format=QuantFormat.QDQ,
297 | 459 |     op_types_to_quantize=None,
@@ -473,6 +635,7 @@ def quantize_static(
473 | 635 |             ("CalibMovingAverage", "moving_average"),
474 | 636 |             ("CalibMovingAverageConstant", "averaging_constant"),
475 | 637 |             ("CalibMaxIntermediateOutputs", "max_intermediate_outputs"),
| 638 | +            ("CalibPercentile", "percentile"), |
476 | 639 |         ]
477 | 640 |         calib_extra_options = {
478 | 641 |             key: extra_options.get(name) for (name, key) in calib_extra_options_keys if name in extra_options
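The new ("CalibPercentile", "percentile") entry lets quantize_static forward a percentile setting from extra_options to the calibrator. Below is a hedged sketch of a call that uses it; the file names are placeholders and `reader` stands for any CalibrationDataReader instance, such as the one sketched earlier.

```python
# Sketch: percentile calibration through quantize_static (placeholder paths).
from onnxruntime.quantization import CalibrationMethod, QuantFormat, quantize_static

quantize_static(
    "model.fp32.onnx",
    "model.qdq.onnx",
    reader,  # any CalibrationDataReader
    quant_format=QuantFormat.QDQ,
    calibrate_method=CalibrationMethod.Percentile,
    extra_options={"CalibPercentile": 99.99},  # forwarded to the calibrator as `percentile`
)
```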
@@ -590,8 +753,8 @@ def inc_dataloader():
590 | 753 |
591 | 754 |
592 | 755 | def quantize_dynamic(
593 | | -    model_input: Union[str, Path, onnx.ModelProto], |
594 | | -    model_output: Union[str, Path], |
| 756 | +    model_input: str | Path | onnx.ModelProto, |
| 757 | +    model_output: str | Path, |
595 | 758 |     op_types_to_quantize=None,
596 | 759 |     per_channel=False,
597 | 760 |     reduce_range=False,
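As the signature above shows (now written with PEP 604 unions), model_input may be either a file path or an in-memory onnx.ModelProto. A minimal sketch with placeholder file names:

```python
# Sketch: dynamic quantization of an already-loaded model (file names are placeholders).
import onnx

from onnxruntime.quantization import quantize_dynamic

model = onnx.load("model.fp32.onnx")
quantize_dynamic(model, "model.dynamic.onnx", per_channel=False)
```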
@@ -690,8 +853,8 @@ def quantize_dynamic(
690 | 853 |
691 | 854 |
692 | 855 | def quantize(
693 | | -    model_input: Union[str, Path, onnx.ModelProto], |
694 | | -    model_output: Union[str, Path], |
| 856 | +    model_input: str | Path | onnx.ModelProto, |
| 857 | +    model_output: str | Path, |
695 | 858 |     quant_config: QuantConfig,
696 | 859 | ):
697 | 860 |     """Quantize a model with QuantConfig.
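quantize() accepts either kind of QuantConfig (a StaticQuantConfig such as one produced by get_qdq_config above, or a DynamicQuantConfig) and runs the corresponding quantization path. A short sketch that reuses the hypothetical `qdq_config` from the earlier example; the output path is a placeholder.

```python
# Sketch: drive quantization from a prepared config.
from onnxruntime.quantization import quantize

quantize("model.fp32.onnx", "model.qdq.onnx", qdq_config)
```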