
Commit 63dd605

Fix untyped float values in quantization tool missing from PR microsoft#18043 (microsoft#19182)
### Description
Extends the code coverage to the Entropy, Histogram, and Distribution calibration methods, fixing bugs found along the way.

### Motivation and Context
Bugs detected in [Olive](https://github.com/microsoft/OLive).
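For context, a minimal sketch (not part of this commit) of how one of these calibration methods is selected through `quantize_static`; the model path, input name, and random feeds below are hypothetical placeholders:

```python
# Hedged sketch: exercising the Entropy calibration path touched by this
# commit. "matmul_fp.onnx" and the random feeds are hypothetical placeholders.
import numpy as np

from onnxruntime.quantization import CalibrationDataReader, CalibrationMethod, QuantType, quantize_static


class RandomDataReader(CalibrationDataReader):
    """Feeds a few random calibration batches; replace with real input data."""

    def __init__(self, n=4):
        self._feeds = iter([{"input": np.random.rand(1, 4).astype(np.float32)} for _ in range(n)])

    def get_next(self):
        return next(self._feeds, None)  # None signals the end of calibration data


quantize_static(
    "matmul_fp.onnx",
    "matmul_qdq.onnx",
    RandomDataReader(),
    activation_type=QuantType.QInt8,
    weight_type=QuantType.QInt8,
    calibrate_method=CalibrationMethod.Entropy,  # or Percentile / Distribution
)
```

The tests added in this commit exercise exactly this path for each calibration method.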
1 parent 9876cc7 commit 63dd605

File tree

3 files changed: +131 −23 lines changed

onnxruntime/python/tools/quantization/calibrate.py

+67 −19
```diff
@@ -5,6 +5,7 @@
 # license information.
 # --------------------------------------------------------------------------
 import abc
+import copy
 import itertools
 import os
 import uuid
@@ -21,6 +22,48 @@
 from .quant_utils import apply_plot, load_model_with_shape_infer, smooth_distribution
 
 
+def rel_entr(pk: np.ndarray, qk: np.ndarray) -> np.ndarray:
+    """
+    See https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.rel_entr.html#scipy.special.rel_entr.
+    Python implementation.
+    """
+    res = np.empty(pk.shape, dtype=pk.dtype)
+    res[:] = pk[:] * np.log(pk[:] / qk[:])
+    c2 = (pk == 0) & (qk >= 0)
+    res[c2] = 0
+    c1 = (pk > 0) & (qk > 0)
+    res[~c1] = np.inf
+    return res
+
+
+def entropy(
+    pk: np.ndarray,
+    qk: np.ndarray,
+    base: Optional[float] = None,
+    axis: int = 0,
+) -> np.ndarray:
+    """
+    Simplified version of entropy.
+    Source: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html.
+    This avoids taking a dependency on scipy just for this function.
+    """
+    assert base is None or base > 0, f"base={base} must be a positive number or `None`."
+    assert qk is not None, "qk is None"
+
+    pk = np.asarray(pk).astype(np.float32)
+    pk = 1.0 * pk / np.sum(pk, axis=axis, keepdims=True)
+
+    qk = np.asarray(qk).astype(np.float32)
+    pk, qk = np.broadcast_arrays(pk, qk)
+    qk = 1.0 * qk / np.sum(qk, axis=axis, keepdims=True)
+    vec = rel_entr(pk, qk)
+
+    s = np.sum(vec, axis=axis)
+    if base is not None:
+        s /= np.log(base)
+    return s.astype(pk.dtype)
+
+
 class TensorData:
     _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"])
     _floats = frozenset(["avg", "std", "lowest", "highest", "hist_edges"])
```
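For reference (an editorial note, not part of the diff): after normalizing `pk` and `qk`, these vendored helpers compute the relative entropy, with the same zero and infinity conventions as `scipy.special.rel_entr`:

```latex
% KL divergence returned by entropy(pk, qk); the sum is divided by
% \log(\text{base}) when a base other than e is requested.
D_{\mathrm{KL}}(p \,\|\, q) = \sum_i r(p_i, q_i), \qquad
r(p_i, q_i) =
\begin{cases}
  p_i \log\left(p_i / q_i\right) & p_i > 0,\ q_i > 0,\\
  0 & p_i = 0,\ q_i \ge 0,\\
  +\infty & \text{otherwise.}
\end{cases}
```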
```diff
@@ -708,8 +751,8 @@ def collect_absolute_value(self, name_to_arr):
                 min_value = np.min(data_arr_np)
                 max_value = np.max(data_arr_np)
             else:
-                min_value = 0
-                max_value = 0
+                min_value = np.array(0, dtype=data_arr_np.dtype)
+                max_value = np.array(0, dtype=data_arr_np.dtype)
 
             data_arr_np = np.absolute(data_arr_np)  # only consider absolute value
 
@@ -725,6 +768,8 @@ def collect_absolute_value(self, name_to_arr):
                 old_histogram = self.histogram_dict[tensor]
                 old_min = old_histogram[2]
                 old_max = old_histogram[3]
+                assert hasattr(old_min, "dtype"), f"old_min should be a numpy array but is {type(old_min)}"
+                assert hasattr(old_max, "dtype"), f"old_max should be a numpy array but is {type(old_max)}"
                 old_hist = old_histogram[0]
                 old_hist_edges = old_histogram[1]
                 temp_amax = np.max(data_arr_np)
@@ -757,7 +802,7 @@ def collect_value(self, name_to_arr):
                 min_value = np.array(0, dtype=data_arr.dtype)
                 max_value = np.array(0, dtype=data_arr.dtype)
 
-            threshold = max(abs(min_value), abs(max_value))
+            threshold = np.array(max(abs(min_value), abs(max_value)), dtype=data_arr.dtype)
 
             if tensor in self.histogram_dict:
                 old_histogram = self.histogram_dict[tensor]
@@ -809,7 +854,7 @@ def merge_histogram(self, old_histogram, data_arr, new_min, new_max, new_threshold):
     def compute_collection_result(self):
         if not self.histogram_dict or len(self.histogram_dict) == 0:
             raise ValueError("Histogram has not been collected. Please run collect() first.")
-        print(f"Finding optimal threshold for each tensor using {self.method} algorithm ...")
+        print(f"Finding optimal threshold for each tensor using {self.method!r} algorithm ...")
 
         if self.method == "entropy":
             return self.compute_entropy()
@@ -938,7 +983,14 @@ def compute_distribution(self):
             assert avg_coef.dtype != np.float64
             assert std_coef.dtype != np.float64
             assert hist_edges.dtype != np.float64
-            thresholds_dict[tensor] = TensorData(avg=avg_coef, std=std_coef, hist=hist, hist_edges=hist_edges)
+            thresholds_dict[tensor] = TensorData(
+                avg=avg_coef,
+                std=std_coef,
+                hist=hist,
+                hist_edges=hist_edges,
+                lowest=hist_edges.min(),
+                highest=hist_edges.max(),
+            )
 
             # Plot histogram for debug only
             if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"):
@@ -952,18 +1004,15 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
        `q` is a truncated version of the original distribution.
        Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
        """
-        import copy
-
-        from scipy.stats import entropy
-
         hist = histogram[0]
         hist_edges = histogram[1]
         num_bins = hist.size
         zero_bin_index = num_bins // 2
         num_half_quantized_bin = num_quantized_bins // 2
 
+        dtype = histogram[1].dtype
         kl_divergence = np.zeros(zero_bin_index - num_half_quantized_bin + 1)
-        thresholds = [(0, 0) for i in range(kl_divergence.size)]
+        thresholds = [(np.array(0, dtype=dtype), np.array(0, dtype=dtype)) for i in range(kl_divergence.size)]
 
         # <------------ num bins ---------------->
         # <--- quantized bins ---->
@@ -983,10 +1032,7 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
             start_index = zero_bin_index - i
             end_index = zero_bin_index + i + 1 if (zero_bin_index + i + 1) <= num_bins else num_bins
 
-            thresholds[i - num_half_quantized_bin] = (
-                float(hist_edges[start_index]),
-                float(hist_edges[end_index]),
-            )
+            thresholds[i - num_half_quantized_bin] = (hist_edges[start_index], hist_edges[end_index])
 
             sliced_distribution = copy.deepcopy(hist[start_index:end_index])
 
@@ -1020,15 +1066,15 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
 
                 norm = sum(nonzeros[start:end])
                 if norm != 0:
-                    q[start:end] = float(quantized_bins[index]) / float(norm)
+                    q[start:end] = quantized_bins[index] / norm
 
             p = smooth_distribution(p)
             q = smooth_distribution(q)
-
-            if isinstance(q, np.ndarray):
-                kl_divergence[i - num_half_quantized_bin] = entropy(p, q)
+            if p is None or q is None:
+                div = np.array(np.inf, dtype=dtype)
             else:
-                kl_divergence[i - num_half_quantized_bin] = float("inf")
+                div = np.array(entropy(p, q), dtype=dtype)
+            kl_divergence[i - num_half_quantized_bin] = div
 
         min_kl_divergence_idx = np.argmin(kl_divergence)
         optimal_threshold = thresholds[min_kl_divergence_idx]
@@ -1038,6 +1084,8 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
             optimal_threshold = (min_value, optimal_threshold[1])
         if optimal_threshold[1] > max_value:
             optimal_threshold = (optimal_threshold[0], max_value)
+        assert hasattr(optimal_threshold[0], "dtype")
+        assert hasattr(optimal_threshold[1], "dtype")
         return optimal_threshold
```
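One way to see the bug this commit fixes (a minimal sketch, assuming a `float16` model; nothing below comes from the diff itself): wrapping a numpy scalar in `float()`, as the old `get_entropy_threshold` did, promotes it to a 64-bit Python float, so thresholds computed for `float16` tensors silently lost their dtype. Keeping the numpy scalars, as the new code does, preserves it:

```python
import numpy as np

# Hypothetical float16 histogram edges, standing in for hist_edges above.
hist_edges = np.array([-1.0, 0.0, 1.0], dtype=np.float16)

promoted = float(hist_edges[0])  # plain Python float: the float16 dtype is lost
preserved = hist_edges[0]        # numpy scalar: keeps dtype float16

print(type(promoted))               # <class 'float'>
print(preserved.dtype)              # float16
print(hasattr(promoted, "dtype"))   # False -- would trip the new asserts
print(hasattr(preserved, "dtype"))  # True
```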

onnxruntime/python/tools/quantization/quant_utils.py

+1 −1
```diff
@@ -653,7 +653,7 @@ def smooth_distribution(p, eps=0.0001):
 
     if not n_nonzeros:
         # raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
-        return -1
+        return None
     eps1 = eps * float(n_zeros) / float(n_nonzeros)
     assert eps1 < 1.0, "n_zeros=%d, n_nonzeros=%d, eps1=%f" % (
         n_zeros,
```
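The caller-side effect of this sentinel change (a sketch; the arrays here are made up): `smooth_distribution` now returns `None` instead of `-1` for an all-zero distribution, which is why `get_entropy_threshold` above checks `p is None or q is None` rather than testing for an ndarray:

```python
import numpy as np

from onnxruntime.quantization.calibrate import entropy
from onnxruntime.quantization.quant_utils import smooth_distribution

p = smooth_distribution(np.zeros(4, dtype=np.float32))  # all-zero input: returns None
q = smooth_distribution(np.array([1.0, 0.0, 3.0, 2.0], dtype=np.float32))

# Degenerate slices get an infinite divergence instead of a bogus number.
kl = np.inf if p is None or q is None else entropy(p, q)
print(kl)  # inf
```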

onnxruntime/test/python/quantization/test_op_matmul.py

+63 −3
```diff
@@ -10,13 +10,39 @@
 import numpy as np
 import onnx
 import packaging.version as pv
+from numpy.testing import assert_almost_equal
 from onnx import TensorProto, helper
 from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type
 
+from onnxruntime.capi.onnxruntime_pybind11_state import Fail
 from onnxruntime.quantization import CalibrationMethod, QuantFormat, QuantType, quantize_dynamic, quantize_static
+from onnxruntime.quantization.calibrate import entropy
+
+
+def skip_if_new_opset_exception_raised(func):
+    def wrapper(*args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except Fail as e:
+            if "is under development and support for this is limited" in str(e):
+                raise unittest.SkipTest(f"Skipped {func} due to opset under development.")  # noqa: B904
+            raise
+
+    return wrapper
 
 
 class TestOpMatMul(unittest.TestCase):
+    def test_entropy(self):
+        try:
+            from scipy.stats import entropy as scipy_entropy
+        except ImportError:
+            raise unittest.SkipTest("scipy not installed.")  # noqa: B904
+        pk = (np.arange(10) - 5).astype(np.float32) / 10
+        qk = -(np.arange(10) - 5).astype(np.float32) / 10
+        ent = scipy_entropy(pk, qk)
+        get = entropy(pk, qk)
+        assert_almost_equal(ent, get)
+
     def input_feeds(self, n, name2shape, dtype):
         input_data_list = []
         for _i in range(n):
@@ -324,10 +350,11 @@ def test_quantize_matmul_u8u8(self):
     @unittest.skipIf(
         pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709"
     )
+    @skip_if_new_opset_exception_raised
     def test_quantize_matmul_u8u8_f16(self):
-        self.quantize_matmul_u8u8(onnx.TensorProto.FLOAT16, 19, 9)
+        self.quantize_matmul_u8u8(onnx.TensorProto.FLOAT16, 21, 9)
 
-    def quantize_matmul_s8s8(self, tt, opset, ir_version):
+    def quantize_matmul_s8s8(self, tt, opset, ir_version, calibrate_method=CalibrationMethod.MinMax):
         np.random.seed(1)
         model_fp_path = "matmul_fp.onnx"
         self.construct_model_matmul(model_fp_path, tensor_type=tt, opset=opset, ir_version=ir_version)
@@ -341,13 +368,15 @@ def quantize_matmul_s8s8(self, tt, opset, ir_version):
             activation_type=QuantType.QInt8,
             weight_type=QuantType.QInt8,
             extra_options={"ActivationSymmetric": True},
+            calibrate_method=calibrate_method,
         )
         self.static_quant_test_qdq(
             model_fp_path,
             data_reader,
             activation_type=QuantType.QInt8,
             weight_type=QuantType.QInt8,
             extra_options={"ActivationSymmetric": True},
+            calibrate_method=calibrate_method,
         )
 
         # dynamic quantization doesn't support activation:int8
@@ -357,11 +386,42 @@
     def test_quantize_matmul_s8s8(self):
         self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8)
 
+    def test_quantize_matmul_s8s8_entropy(self):
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8, calibrate_method=CalibrationMethod.Entropy)
+
+    def test_quantize_matmul_s8s8_percentile(self):
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8, calibrate_method=CalibrationMethod.Percentile)
+
+    def test_quantize_matmul_s8s8_distribution(self):
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT, 18, 8, calibrate_method=CalibrationMethod.Distribution)
+
     @unittest.skipIf(
         pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709"
     )
+    @skip_if_new_opset_exception_raised
     def test_quantize_matmul_s8s8_f16(self):
-        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 19, 9)
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9)
+
+    @unittest.skipIf(
+        pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709"
+    )
+    @skip_if_new_opset_exception_raised
+    def test_quantize_matmul_s8s8_f16_entropy(self):
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Entropy)
+
+    @unittest.skipIf(
+        pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709"
+    )
+    @skip_if_new_opset_exception_raised
+    def test_quantize_matmul_s8s8_f16_percentile(self):
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Percentile)
+
+    @unittest.skipIf(
+        pv.Version(onnx.__version__) < pv.Version("1.15.1"), reason="Shape inference bug, see onnx PR #5709"
+    )
+    @skip_if_new_opset_exception_raised
+    def test_quantize_matmul_s8s8_f16_distribution(self):
+        self.quantize_matmul_s8s8(onnx.TensorProto.FLOAT16, 21, 9, calibrate_method=CalibrationMethod.Distribution)
 
     def quantize_matmul_e4m3fn_same(self, tt, opset, ir_version):
         np.random.seed(1)
```
